corrscope/ovgenpy/triggers.py

237 wiersze
6.7 KiB
Python

from abc import ABC, abstractmethod
from typing import NamedTuple, List, Dict, Any, TYPE_CHECKING
from dataclasses import dataclass
import numpy as np
from matplotlib import pyplot as plt
from scipy import signal
from ovgenpy.renderer import MatplotlibRenderer, RendererConfig
if TYPE_CHECKING:
from ovgenpy.wave import Wave
class Trigger(ABC):
    """Abstract base for objects that choose a trigger point within a wave."""

    def __init__(self, wave: 'Wave', scan_nsamp: int):
        """
        :param wave: Wave to search for trigger points
        :param scan_nsamp: Number of samples used to align adjacent frames
        """
        self._wave, self._scan_nsamp = wave, scan_nsamp

    @abstractmethod
    def get_trigger(self, offset: int) -> int:
        """
        Locate a trigger point near `offset`. Subclasses must override this.

        :param offset: sample index
        :return: new sample index, corresponding to rising edge
        """
class TriggerConfig:
    """
    Base class for trigger configurations; calling a config builds its Trigger.

    NamedTuple inheritance does not work
    (https://github.com/python/typing/issues/427), so subclasses are marked
    @dataclass instead.
    """

    def __call__(self, wave: 'Wave', scan_nsamp: int):
        # Concrete configs override this to construct their Trigger.
        raise NotImplementedError
# Debug switches, both off by default.
# SHOW_TRIGGER: create a TriggerRenderer and redraw it on every buffer update.
# SHOW_TRIGGER2: draw intermediate get_trigger() arrays as matplotlib subplots
#   and call plt.show() after each trigger search.
SHOW_TRIGGER = SHOW_TRIGGER2 = False
def lerp(x: np.ndarray, y: np.ndarray, a: float):
    """
    Linearly interpolate from `x` (at a=0) to `y` (at a=1).
    """
    weight_x = 1 - a
    return x * weight_x + y * a
class Dummy:
    """
    Null object: every attribute lookup and every call returns this same
    instance, so chains like ``Dummy().plot(data)`` silently do nothing.
    """

    def __getattr__(self, item):
        return self

    def __call__(self, *_args, **_kwargs):
        return self
def plots(nplot):
    """
    Yield `nplot` objects to draw debug plots on.

    When SHOW_TRIGGER2 is set, each one is a matplotlib subplot, stacked
    vertically in a single column. Otherwise each one is a Dummy that
    ignores every call.
    """
    for row in range(1, nplot + 1):
        yield plt.subplot(nplot, 1, row) if SHOW_TRIGGER2 else Dummy()
class CorrelationTrigger(Trigger):
    """
    Trigger that aligns each frame by cross-correlating the wave data around
    the requested offset with a buffer accumulated from previously aligned
    frames (plus a centered rising step).
    """

    # Lower bound on the divisor used by _normalize_buffer, so near-silent
    # data is not amplified without limit (and all-zero data never divides by 0).
    MIN_AMPLITUDE = 0.01

    @dataclass
    class Config(TriggerConfig):
        # get_trigger
        trigger_strength: float  # magnitude of the rising step (before windowing)

        # _update_buffer
        responsiveness: float  # lerp weight given to new data when updating the buffer
        falloff_width: float  # Gaussian window std, in multiples of the detected period

        def __call__(self, wave: 'Wave', scan_nsamp: int):
            return CorrelationTrigger(wave, scan_nsamp, cfg=self)

    def __init__(self, wave: 'Wave', scan_nsamp: int, cfg: Config):
        """
        Correlation-based trigger which looks at a window of `scan_nsamp` samples.

        :param wave: Wave file
        :param scan_nsamp: Number of samples used to align adjacent frames
        :param cfg: Correlation config
        """
        Trigger.__init__(self, wave, scan_nsamp)
        self._buffer_nsamp = self._scan_nsamp

        # Correlation config
        self.cfg = cfg

        # Create correlation buffer (containing a series of old data)
        self._buffer = np.zeros(scan_nsamp)

        if SHOW_TRIGGER:
            self._trigger_renderer = TriggerRenderer(self)

    def get_trigger(self, offset: int) -> int:
        """
        Find the sample index near `offset` that best matches the buffer.

        The search covers shifts of up to ±N//4 samples, where N is the length
        of the data window returned by the wave. As a side effect, the
        correlation buffer is updated with the data around the chosen trigger.

        :param offset: sample index
        :return: new sample index, corresponding to rising edge
        """
        trigger_strength = self.cfg.trigger_strength

        data = self._wave.get_around(offset, self._buffer_nsamp)
        N = len(data)

        # Debug plots; these are no-ops unless SHOW_TRIGGER2 is set.
        ps = plots(4)
        next(ps).plot(data)

        # Add "step function" to correlation buffer. It is negative before the
        # center and positive after it, so correlation favors alignments that
        # put a rising edge at the center of the window.
        halfN = N // 2
        # scipy.signal.gaussian was removed in SciPy 1.13; windows.gaussian
        # is the supported location (available since SciPy 1.1).
        window = signal.windows.gaussian(N, std=halfN // 3)

        step = np.empty(N)
        step[:halfN] = -trigger_strength / 2
        step[halfN:] = trigger_strength / 2
        step *= window

        prev_buffer = self._buffer + step
        next(ps).plot(prev_buffer)

        # Find optimal offset (within ±N//4)
        delta = N - 1
        radius = N // 4

        # Calculate correlation.
        # If offset < optimal, we need to `offset += positive`.
        # - The peak will appear near the right of `data`.
        # Either we must slide prev_buffer to the right:
        # - correlate(data, prev_buffer)
        # - trigger = offset + peak_offset
        # Or we must slide data to the left (by sliding offset to the right):
        # - correlate(prev_buffer, data)
        # - trigger = offset - peak_offset
        # The first form is used here.
        corr = signal.correlate(data, prev_buffer)
        assert len(corr) == 2 * N - 1
        next(ps).plot(corr)

        # In full mode, index N-1 is zero lag; keep lags in [-radius, radius].
        corr = corr[delta - radius: delta + radius + 1]
        delta = radius
        next(ps).plot(corr)

        # argmax(corr) == delta + peak_offset == (data >> peak_offset)
        # peak_offset == argmax(corr) - delta
        # Cast to a Python int so the return value matches the `-> int` annotation.
        peak_offset = int(np.argmax(corr)) - delta
        trigger = offset + peak_offset

        # Update correlation buffer (distinct from visible area)
        aligned = self._wave.get_around(trigger, self._buffer_nsamp)
        self._update_buffer(aligned)

        if SHOW_TRIGGER2:
            plt.show()

        return trigger

    def _update_buffer(self, data: np.ndarray) -> None:
        """
        Blend `data` into self._buffer.

        `data` is normalized and then multiplied by a Gaussian window whose
        std is `falloff_width` times its detected period, so it tapers away
        from the center. The old buffer is also normalized, and the two are
        combined with lerp(old, new, responsiveness).

        :param data: Wave data. WILL BE MODIFIED.
        :raises ValueError: if len(data) differs from the buffer length.
        """
        falloff_width = self.cfg.falloff_width
        responsiveness = self.cfg.responsiveness

        N = len(data)
        if N != self._buffer_nsamp:
            raise ValueError(f'invalid data length {len(data)} does not match '
                             f'CorrelationTrigger {self._buffer_nsamp}')

        # New waveform
        self._normalize_buffer(data)

        wave_period = get_period(data)
        window = signal.windows.gaussian(N, std=wave_period * falloff_width)
        data *= window

        # Old buffer
        self._normalize_buffer(self._buffer)
        self._buffer = lerp(self._buffer, data, responsiveness)

        if SHOW_TRIGGER:
            self._trigger_renderer.render_frame()

    # Reads only MIN_AMPLITUDE from self; mutates just the array it is given.
    def _normalize_buffer(self, data: np.ndarray) -> None:
        """
        Rescales `data` in-place so its peak absolute value becomes 1.

        If the peak is below MIN_AMPLITUDE, `data` is divided by MIN_AMPLITUDE
        instead, so quiet data is not blown up to full scale.
        """
        peak = np.amax(abs(data))
        data /= max(peak, self.MIN_AMPLITUDE)
def get_period(data: np.ndarray) -> int:
    """
    Use autocorrelation to estimate the period of a signal.
    Loosely inspired by https://github.com/endolith/waveform_analysis

    The lag-0 peak is skipped by searching only from the first lag where the
    autocorrelation goes negative. The lag of the highest autocorrelation at
    or after that point is taken as the period.

    :param data: 1-D array of samples
    :return: estimated period in samples, always a Python int. If the
        autocorrelation never goes negative (e.g. all-zero input), this
        returns len(data).
    """
    corr = signal.correlate(data, data, mode='full', method='fft')
    # Keep non-negative lags only; index 0 is now zero lag.
    corr = corr[len(corr) // 2:]

    # Remove the zero-correlation peak
    zero_crossings = np.flatnonzero(corr < 0)
    if zero_crossings.size == 0:
        # This can happen given an array of all zeros. Anything else?
        return len(data)

    # Cast numpy integers to int so both return paths have the same type.
    crossX = int(zero_crossings[0])
    peakX = crossX + int(np.argmax(corr[crossX:]))
    return peakX
# Debug-only renderer; defined only when SHOW_TRIGGER is enabled
# (CorrelationTrigger creates one in __init__ under the same flag).
if SHOW_TRIGGER:
    class TriggerRenderer(MatplotlibRenderer):
        """Matplotlib renderer that displays a CorrelationTrigger's correlation buffer."""
        # TODO swappable GraphRenderer class shouldn't depend on waves
        # probably don't need to debug multiple triggers

        def __init__(self, trigger: CorrelationTrigger):
            """
            :param trigger: trigger whose `_buffer` will be drawn
            """
            self.trigger = trigger
            # NOTE(review): positional args presumably mean width=640,
            # height=360, samples-per-frame = buffer length; confirm
            # against RendererConfig.
            cfg = RendererConfig(
                640, 360, trigger._buffer_nsamp, rows_first=False, ncols=1
            )
            # NOTE(review): [None] appears to be a single placeholder wave
            # slot; confirm against MatplotlibRenderer's signature.
            super().__init__(cfg, [None])

        def render_frame(self) -> None:
            """Copy the trigger's current correlation buffer into the first plot line."""
            idx = 0

            # Draw trigger buffer data
            line = self.lines[idx]
            data = self.trigger._buffer
            line.set_ydata(data)