diff --git a/ovgenpy/channel.py b/ovgenpy/channel.py index ae9f137..38c325f 100644 --- a/ovgenpy/channel.py +++ b/ovgenpy/channel.py @@ -41,7 +41,7 @@ class Channel: ) self.trigger = tcfg( wave=self.wave, - scan_nsamp=trigger_nsamp, # TODO rename to trigger_nsamp - # FIXME self.trigger_subsampling + nsamp=trigger_nsamp, + subsampling=ovgen_cfg.subsampling * self.trigger_subsampling ) diff --git a/ovgenpy/triggers.py b/ovgenpy/triggers.py index ba7f5a9..11eb356 100644 --- a/ovgenpy/triggers.py +++ b/ovgenpy/triggers.py @@ -1,6 +1,5 @@ -import weakref from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Type import numpy as np from scipy import signal @@ -14,10 +13,31 @@ if TYPE_CHECKING: from ovgenpy.wave import Wave +class ITriggerConfig: + cls: Type['Trigger'] + + def __call__(self, wave: 'Wave', nsamp: int, subsampling: int): + return self.cls(wave, cfg=self, nsamp=nsamp, subsampling=subsampling) + + +def register_trigger(config_t: Type[ITriggerConfig]): + """ @register_trigger(FooTriggerConfig) + def FooTrigger(): ... + """ + def inner(trigger_t: Type[Trigger]): + config_t.cls = trigger_t + return trigger_t + + return inner + + class Trigger(ABC): - def __init__(self, wave: 'Wave', scan_nsamp: int): - self._wave: Wave = weakref.proxy(wave) - self._scan_nsamp = scan_nsamp + def __init__(self, wave: 'Wave', cfg: ITriggerConfig, nsamp: int, subsampling: int): + self.cfg = cfg + self._wave = wave + + self._trigger_nsamp = nsamp + self._trigger_subsampling = subsampling @abstractmethod def get_trigger(self, index: int) -> int: @@ -28,14 +48,6 @@ class Trigger(ABC): ... -class ITriggerConfig: - def __call__(self, wave: 'Wave', scan_nsamp: int): - # idea: __call__ return self.cls(wave, scan_nsamp, cfg=self) - # problem: cannot reference XTrigger from within XTrigger - # solution: @register_trigger(XTriggerCfg) - raise NotImplementedError - - def lerp(x: np.ndarray, y: np.ndarray, a: float): return x * (1 - a) + y * a @@ -50,36 +62,31 @@ class CorrelationTriggerConfig(ITriggerConfig): responsiveness: float falloff_width: float - def __call__(self, wave: 'Wave', scan_nsamp: int): - return CorrelationTrigger(wave, scan_nsamp, cfg=self) - +@register_trigger(CorrelationTriggerConfig) class CorrelationTrigger(Trigger): MIN_AMPLITUDE = 0.01 - # get_trigger postprocessing: self._zero_trigger ZERO_CROSSING_SCAN = 256 + cfg: CorrelationTriggerConfig - def __init__(self, wave: 'Wave', scan_nsamp: int, cfg: CorrelationTriggerConfig): + def __init__(self, *args, **kwargs): """ - Correlation-based trigger which looks at a window of `scan_nsamp` samples. - + Correlation-based trigger which looks at a window of `trigger_nsamp` samples. it's complicated - - :param wave: Wave file - :param scan_nsamp: Number of samples used to align adjacent frames - :param cfg: Correlation config """ - Trigger.__init__(self, wave, scan_nsamp) - self._buffer_nsamp = self._scan_nsamp - - # Correlation config - self.cfg = cfg + Trigger.__init__(self, *args, **kwargs) + self._buffer_nsamp = self._trigger_nsamp # Create correlation buffer (containing a series of old data) - self._buffer = np.zeros(scan_nsamp, dtype=FLOAT) # type: np.ndarray[FLOAT] + self._buffer = np.zeros(self._buffer_nsamp, dtype=FLOAT) # type: np.ndarray[FLOAT] # Create zero crossing trigger, for postprocessing results - self._zero_trigger = ZeroCrossingTrigger(wave, self.ZERO_CROSSING_SCAN) + self._zero_trigger = ZeroCrossingTrigger( + self._wave, + ITriggerConfig(), + nsamp=self.ZERO_CROSSING_SCAN, + subsampling=1, + ) def get_trigger(self, index: int) -> int: """ @@ -199,11 +206,8 @@ def get_period(data: np.ndarray) -> int: class ZeroCrossingTrigger(Trigger): - def __init__(self, wave: 'Wave', scan_nsamp: int): - super().__init__(wave, scan_nsamp) - def get_trigger(self, index: int): - scan_nsamp = self._scan_nsamp + trigger_nsamp = self._trigger_nsamp if not 0 <= index < self._wave.nsamp: return index @@ -219,7 +223,7 @@ class ZeroCrossingTrigger(Trigger): else: # self._wave[sample] == 0 return index + 1 - data = self._wave[index : index + (direction * scan_nsamp) : direction] + data = self._wave[index : index + (direction * trigger_nsamp) : direction] intercepts = find(data, test) try: (delta,), value = next(intercepts) diff --git a/ovgenpy/wave.py b/ovgenpy/wave.py index e87747a..56bc236 100644 --- a/ovgenpy/wave.py +++ b/ovgenpy/wave.py @@ -5,7 +5,7 @@ from ovgenpy.config import dataclass from scipy.io import wavfile -# Internal class, not exposed via YAML (TODO replace with ChannelConfig?) +# Internal class, not exposed via YAML @dataclass class _WaveConfig: amplification: float = 1 diff --git a/tests/test_trigger.py b/tests/test_trigger.py index 9700285..6962fed 100644 --- a/tests/test_trigger.py +++ b/tests/test_trigger.py @@ -23,14 +23,14 @@ def cfg(request): ) -def test_trigger(cfg): +def test_trigger(cfg: CorrelationTriggerConfig): # wave = Wave(None, 'tests/sine440.wav') wave = Wave(None, 'tests/impulse24000.wav') iters = 5 plot = False x = 24000 - 500 - trigger = cfg(wave, 4000) + trigger = cfg(wave, 4000, subsampling=1) if plot: BIG = 0.95