corrscope/ovgenpy/triggers.py

376 lines
11 KiB
Python

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Type, Tuple
import numpy as np
from scipy import signal
from scipy.signal import windows
from ovgenpy.config import register_config, OvgenError, Alias
from ovgenpy.util import find
from ovgenpy.utils.windows import midpad, leftpad
from ovgenpy.wave import FLOAT
if TYPE_CHECKING:
from ovgenpy.wave import Wave
# Abstract classes
class ITriggerConfig:
    """Base class for trigger configurations.

    Calling a config instance constructs the trigger class stored in `cls`
    (bound by @register_trigger), passing the config itself as `cfg`.
    """
    cls: Type['Trigger']

    def __call__(self, wave: 'Wave', tsamp: int, subsampling: int, fps: float):
        trigger_t = self.cls
        return trigger_t(wave, cfg=self, tsamp=tsamp, subsampling=subsampling, fps=fps)
def register_trigger(config_t: Type['ITriggerConfig']):
    """Class decorator which binds a Trigger subclass to its config class.

    Usage:

        @register_trigger(FooTriggerConfig)
        class FooTrigger(Trigger): ...

    Afterwards `FooTriggerConfig.cls is FooTrigger`, so calling a
    FooTriggerConfig instance (see ITriggerConfig.__call__) constructs a
    FooTrigger.

    :param config_t: config class to bind the decorated trigger class to.
    :return: decorator which stores its argument on `config_t.cls` and returns
        that argument unchanged.
    """
    # Annotations are forward-reference strings so this decorator does not
    # depend on the order in which module-level classes are defined.
    def inner(trigger_t: Type['Trigger']) -> Type['Trigger']:
        config_t.cls = trigger_t
        return trigger_t

    return inner
class Trigger(ABC):
    """Abstract base for triggers: given a sample index, pick a new sample index
    (see get_trigger)."""

    def __init__(self, wave: 'Wave', cfg: ITriggerConfig, tsamp: int, subsampling: int,
                 fps: float):
        """
        :param wave: wave to read samples from.
        :param cfg: trigger configuration, stored as `self.cfg`.
        :param tsamp: window length, in subsampled units (the same units that
            time2tsamp() produces).
        :param subsampling: stride between consecutive subsamples.
        :param fps: output frames per second.
        """
        self.cfg = cfg
        self._wave = wave
        self._tsamp = tsamp
        self._subsampling = subsampling
        self._fps = fps

        seconds_per_frame = 1 / fps
        # Subsamples per frame
        self._tsamp_frame = self.time2tsamp(seconds_per_frame)
        # Real (non-subsampled) samples per frame
        self._real_samp_frame = round(seconds_per_frame * self._wave.smp_s)

    def time2tsamp(self, time: float):
        """Convert a duration in seconds to a rounded count of subsamples."""
        real_samples = time * self._wave.smp_s
        return round(real_samples / self._subsampling)

    @abstractmethod
    def get_trigger(self, index: int) -> int:
        """
        :param index: sample index
        :return: new sample index, corresponding to rising edge
        """
# CorrelationTrigger
@register_config(always_dump='''
use_edge_trigger
edge_strength
responsiveness
buffer_falloff
''')
class CorrelationTriggerConfig(ITriggerConfig):
    """Configuration for CorrelationTrigger."""

    # get_trigger
    # If true, the correlation result is post-processed by a zero-crossing scan.
    use_edge_trigger: bool = True
    # Height of the step function added to the correlation buffer
    # (it runs from -edge_strength/2 to +edge_strength/2).
    edge_strength: float = 10.0
    # Width of the correlation-peak search range, as a fraction of the window
    # (the search covers +-window * trigger_diameter / 2).
    trigger_diameter: float = 0.5
    # (flat width, cosine falloff width) of the input window, in multiples of
    # the estimated wave period.
    trigger_falloff: Tuple[float, float] = (4.0, 1.0)
    # Width of the taper suppressing 1-frame-old data, as a fraction of a frame.
    lag_prevention: float = 0.25

    # _update_buffer
    # Interpolation factor used to blend new data into the correlation buffer.
    responsiveness: float = 0.1
    # Std of the Gaussian window applied to new buffer data, in wave periods.
    buffer_falloff: float = 0.5

    # region Legacy Aliases
    trigger_strength = Alias('edge_strength')
    falloff_width = Alias('buffer_falloff')
    # endregion

    def __post_init__(self):
        """Validate parameters. Raises ValueError on out-of-range values."""
        self._validate_param('lag_prevention', 0, 1)
        self._validate_param('responsiveness', 0, 1)
        self._validate_param('buffer_falloff', 0, np.inf)
        self._validate_falloff()

    def _validate_param(self, key: str, begin, end):
        """Raise ValueError unless begin <= getattr(self, key) <= end."""
        value = getattr(self, key)
        if not begin <= value <= end:
            raise ValueError(
                f'Invalid {key}={value} (should be within [{begin}, {end}])')

    def _validate_falloff(self):
        """Raise ValueError unless trigger_falloff is two non-negative widths."""
        # get_trigger() unpacks trigger_falloff into (diameter, falloff) widths
        # for cosine_flat(). A negative width would only fail later, while
        # rendering (negative array/window lengths).
        value = self.trigger_falloff
        if len(value) != 2 or not all(x >= 0 for x in value):
            raise ValueError(
                f'Invalid trigger_falloff={value} (should be two values >= 0)')
@register_trigger(CorrelationTriggerConfig)
class CorrelationTrigger(Trigger):
    """Trigger which aligns each frame by cross-correlating windowed input data
    against a running buffer of previously-aligned data plus a step function."""

    # Floor for the peak used by _normalize_buffer(), so near-silent data is not
    # amplified without bound.
    MIN_AMPLITUDE = 0.01
    # Scan length (in samples) of the zero-crossing post-processing trigger.
    ZERO_CROSSING_SCAN = 256

    cfg: CorrelationTriggerConfig

    def __init__(self, *args, **kwargs):
        """
        Correlation-based trigger which looks at a window of `tsamp` subsamples
        (stored as self._buffer_nsamp). Arguments are forwarded unchanged to
        Trigger.__init__.
        """
        Trigger.__init__(self, *args, **kwargs)
        self._buffer_nsamp = self._tsamp

        # Create correlation buffer (containing a series of old data)
        self._buffer = np.zeros(self._buffer_nsamp, dtype=FLOAT)  # type: np.ndarray[FLOAT]

        # Create zero crossing trigger, for postprocessing results
        self._zero_trigger = ZeroCrossingTrigger(
            self._wave,
            ITriggerConfig(),
            tsamp=self.ZERO_CROSSING_SCAN,
            subsampling=1,
            fps=self._fps
        )

        # Precompute edge trigger step
        self._windowed_step = self._calc_step()

        # Input data taper (zeroes out all data older than 1 frame old)
        self._data_taper = self._calc_data_taper()  # Rejected idea: right cosine taper

        # For debug output: when True, get_trigger() stores its window in
        # self._prev_window.
        self.save_window = False

    def _calc_step(self):
        """ Step function used for approximate edge triggering. """
        edge_strength = self.cfg.edge_strength
        N = self._buffer_nsamp
        halfN = N // 2

        # -edge_strength/2 on the left half, +edge_strength/2 on the right half...
        step = np.empty(N, dtype=FLOAT)  # type: np.ndarray[FLOAT]
        step[:halfN] = -edge_strength / 2
        step[halfN:] = edge_strength / 2
        # ...then tapered by a Gaussian centred on the window.
        step *= windows.gaussian(N, std=halfN / 3)
        return step

    def _calc_data_taper(self):
        """ Input data window. Zeroes out all data older than 1 frame old.
        See https://github.com/nyanpasu64/ovgenpy/wiki/Correlation-Trigger
        """
        N = self._buffer_nsamp
        halfN = N // 2

        # To avoid cutting off data, use a narrow transition zone (invariant to
        # subsampling).
        transition_nsamp = round(self._real_samp_frame * self.cfg.lag_prevention)
        tsamp_frame = self._tsamp_frame

        # Left half of a Hann cosine taper
        # Width = min(subsampling*frame * lag_prevention, 1 frame)
        width = min(transition_nsamp, tsamp_frame)
        taper = windows.hann(width * 2)[:width]

        # Right-pad taper to 1 frame long
        if width < tsamp_frame:
            taper = np.pad(taper, (0, tsamp_frame - width), 'constant',
                           constant_values=1)
        assert len(taper) == tsamp_frame

        # Reshape taper to left `halfN` of data_window (right-aligned).
        taper = leftpad(taper, halfN)

        # Generate left half-taper to prevent correlating with 1-frame-old data.
        data_window = np.ones(N)
        data_window[:halfN] = np.minimum(data_window[:halfN], taper)

        return data_window

    def get_trigger(self, index: int) -> int:
        """
        :param index: sample index
        :return: new sample index, corresponding to rising edge
        """
        N = self._buffer_nsamp
        use_edge_trigger = self.cfg.use_edge_trigger

        # Get data
        data = self._wave.get_around(index, N, self._subsampling)

        # Window data: flat-top cosine window sized from the estimated period,
        # intersected with the precomputed lag-prevention taper.
        period = get_period(data)
        diameter, falloff = [round(period * x) for x in self.cfg.trigger_falloff]
        falloff_window = cosine_flat(N, diameter, falloff)
        window = np.minimum(falloff_window, self._data_taper)

        # NOTE(review): scales `data` in place; assumes get_around() returns a
        # fresh array rather than a view of the wave -- TODO confirm.
        data *= window
        if self.save_window:
            self._prev_window = window

        # prev_buffer
        prev_buffer = self._windowed_step + self._buffer

        # Calculate correlation
        # If offset < optimal, we need to `offset += positive`.
        # - The peak will appear near the right of `data`.
        # Either we must slide prev_buffer to the right:
        # - correlate(data, prev_buffer)
        # - trigger = offset + peak_offset
        # Or we must slide data to the left (by sliding offset to the right):
        # - correlate(prev_buffer, data)
        # - trigger = offset - peak_offset
        corr = signal.correlate(data, prev_buffer)
        assert len(corr) == 2*N - 1

        # Find optimal offset (within +-radius, radius = N * trigger_diameter / 2)
        mid = N-1
        radius = round(N * self.cfg.trigger_diameter / 2)
        # Clamp so `left` never goes negative: a negative slice start would wrap
        # around to the end of `corr` and silently give a wrong peak_offset.
        radius = min(radius, N - 1)

        left = mid - radius
        right = mid + radius + 1

        corr = corr[left:right]
        mid = mid - left

        # argmax(corr) == mid + peak_offset == (data >> peak_offset)
        # peak_offset == argmax(corr) - mid
        peak_offset = np.argmax(corr) - mid  # type: int
        trigger = index + (self._subsampling * peak_offset)

        # Update correlation buffer (distinct from visible area)
        aligned = self._wave.get_around(trigger, self._buffer_nsamp, self._subsampling)
        self._update_buffer(aligned, period)

        if use_edge_trigger:
            return self._zero_trigger.get_trigger(trigger)
        else:
            return trigger

    def _update_buffer(self, data: np.ndarray, wave_period: int) -> None:
        """
        Update self._buffer by adding `data` and a step function.
        Data is reshaped to taper away from the center.

        :param data: Wave data. WILL BE MODIFIED.
        :param wave_period: estimated period of `data`, in subsamples; scales
            the Gaussian window together with cfg.buffer_falloff.
        :raises ValueError: if len(data) differs from the buffer length.
        """
        buffer_falloff = self.cfg.buffer_falloff
        responsiveness = self.cfg.responsiveness

        N = len(data)
        if N != self._buffer_nsamp:
            raise ValueError(f'invalid data length {len(data)} does not match '
                             f'CorrelationTrigger {self._buffer_nsamp}')

        # New waveform
        self._normalize_buffer(data)

        # NOTE(review): buffer_falloff == 0 passes config validation but makes
        # std == 0 here, so gaussian() divides by zero (all-zero window, NaN at
        # the centre for odd N). Consider rejecting 0 in the config.
        window = windows.gaussian(N, std = wave_period * buffer_falloff)
        data *= window

        # Old buffer
        self._normalize_buffer(self._buffer)
        self._buffer = lerp(self._buffer, data, responsiveness)

    # Does not modify `self`; only rescales the array passed in.
    def _normalize_buffer(self, data: np.ndarray) -> None:
        """
        Rescales `data` in-place, dividing by its peak absolute value (floored at
        MIN_AMPLITUDE).
        """
        peak = np.amax(abs(data))
        data /= max(peak, self.MIN_AMPLITUDE)
def get_period(data: np.ndarray) -> int:
    """
    Use autocorrelation to estimate the period of a signal.
    Loosely inspired by https://github.com/endolith/waveform_analysis

    :param data: signal samples (the lag slicing below assumes a 1-D array).
    :return: estimated period in samples, as a Python int. Returns len(data)
        when the autocorrelation never goes negative (e.g. all-zero input).
    """
    corr = signal.correlate(data, data, mode='full', method='fft')
    # Keep non-negative lags only; lag 0 sits at the centre of the 'full' output.
    corr = corr[len(corr) // 2:]

    # Remove the zero-correlation peak: skip to the first negative lag.
    zero_crossings = np.where(corr < 0)[0]
    if len(zero_crossings) == 0:
        # This can happen given an array of all zeros. Anything else?
        return len(data)

    crossX = zero_crossings[0]
    # The strongest autocorrelation peak after that point is taken as the period.
    peakX = crossX + np.argmax(corr[crossX:])
    # Convert from a numpy integer scalar so the result matches `-> int`.
    return int(peakX)
def cosine_flat(n: int, diameter: int, falloff: int):
    """
    Build a flat-topped window: `falloff` rising Hann samples, `diameter` ones,
    then `falloff` falling Hann samples, fit to length `n` by midpad()
    (presumably centred -- see midpad).
    """
    hann = windows.hann(2 * falloff)
    rise = hann[:falloff]
    fall = hann[falloff:]
    flat = np.ones(diameter)
    return midpad(np.concatenate((rise, flat, fall)), n)
def lerp(x: np.ndarray, y: np.ndarray, a: float):
    """Linearly interpolate from `x` (at a=0) to `y` (at a=1)."""
    return (1 - a) * x + a * y
# ZeroCrossingTrigger
class ZeroCrossingTrigger(Trigger):
    """Trigger which moves `index` to a nearby zero crossing, scanning at most
    `tsamp` samples."""
    # TODO support subsampling

    def get_trigger(self, index: int):
        """
        :param index: sample index
        :return: sample index of a nearby rising edge; `index` unchanged if it
            is out of range or no crossing is found within `tsamp` samples.
        :raises OvgenError: if subsampling != 1 (not implemented).
        """
        if self._subsampling != 1:
            raise OvgenError(
                f'ZeroCrossingTrigger with subsampling != 1 is not implemented '
                f'(supplied {self._subsampling})')
        tsamp = self._tsamp

        if not 0 <= index < self._wave.nsamp:
            return index

        if self._wave[index] < 0:
            # Below zero: scan right for the first sample >= 0.
            direction = 1
            test = lambda a: a >= 0
        elif self._wave[index] > 0:
            # Above zero: scan left for the first sample <= 0.
            direction = -1
            test = lambda a: a <= 0
        else:  # self._wave[sample] == 0
            return index + 1

        # Take up to `tsamp` samples starting at `index`, walking in `direction`.
        # When scanning left near the start of the wave, `stop` would go
        # negative and (with numpy-style slicing) wrap around to the end of the
        # wave, producing an empty slice. Use None to scan down to sample 0.
        stop = index + (direction * tsamp)
        if stop < 0:
            stop = None
        data = self._wave[index : stop : direction]

        intercepts = find(data, test)
        try:
            (delta,), value = next(intercepts)
        except StopIteration:  # No zero-intercepts
            return index

        # Why `int(value <= 0)`:
        # `value <= 0` produces poor results on sine waves, since it erroneously
        # increments the exact idx of the zero-crossing sample.
        # `value < 0` produces poor results on impulse24000, since idx = 23999 which
        # doesn't match CorrelationTrigger. (scans left looking for a zero-crossing)
        # CorrelationTrigger tries to maximize @trigger - @(trigger-1). I think always
        # incrementing zeros (impulse24000 = 24000) is acceptable.
        # - To be consistent, we should increment zeros whenever we *start* there.
        return index + (delta * direction) + int(value <= 0)
# NullTrigger
@register_config
class NullTriggerConfig(ITriggerConfig):
    """Configuration for NullTrigger; it defines no options of its own."""
@register_trigger(NullTriggerConfig)
class NullTrigger(Trigger):
    """Trigger which performs no analysis."""

    def get_trigger(self, index: int) -> int:
        """Return `index` unchanged."""
        new_index = index
        return new_index