kopia lustrzana https://github.com/corrscope/corrscope
Add ITriggerConfig.post and Trigger.post for trigger chaining
One would think `use_edge_trigger` should be an InitVar. But it breaks test_trigger_subsampling() which is parameterized on `cfg.use_edge_trigger` (and I want to keep it that way, not peek at `cfg.post`). I also copied test_trigger_subsampling to test_post_trigger_subsampling, to test `cfg.post`.pull/357/head
rodzic
6c4e128031
commit
cc7b9b35ac
|
@ -1,5 +1,6 @@
|
|||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Type, Tuple, Optional
|
||||
from typing import TYPE_CHECKING, Type, Tuple, Optional, ClassVar
|
||||
|
||||
import numpy as np
|
||||
from scipy import signal
|
||||
|
@ -7,7 +8,7 @@ from scipy.signal import windows
|
|||
|
||||
from ovgenpy.config import register_config, OvgenError, Alias
|
||||
from ovgenpy.util import find
|
||||
from ovgenpy.utils.keyword_dataclasses import dataclass
|
||||
from ovgenpy.utils.keyword_dataclasses import dataclass, InitVar, field
|
||||
from ovgenpy.utils.windows import midpad, leftpad
|
||||
from ovgenpy.wave import FLOAT
|
||||
|
||||
|
@ -17,10 +18,15 @@ if TYPE_CHECKING:
|
|||
|
||||
# Abstract classes
|
||||
|
||||
@dataclass
|
||||
class ITriggerConfig:
|
||||
cls: Type['Trigger']
|
||||
cls: ClassVar[Type['Trigger']]
|
||||
|
||||
def __call__(self, wave: 'Wave', tsamp: int, subsampling: int, fps: float):
|
||||
# Optional trigger for postprocessing
|
||||
post: 'ITriggerConfig' = None
|
||||
|
||||
def __call__(self, wave: 'Wave', tsamp: int, subsampling: int, fps: float) \
|
||||
-> 'Trigger':
|
||||
return self.cls(wave, cfg=self, tsamp=tsamp, subsampling=subsampling, fps=fps)
|
||||
|
||||
|
||||
|
@ -36,6 +42,8 @@ def register_trigger(config_t: Type[ITriggerConfig]):
|
|||
|
||||
|
||||
class Trigger(ABC):
|
||||
POST_PROCESSING_NSAMP = 256
|
||||
|
||||
def __init__(self, wave: 'Wave', cfg: ITriggerConfig, tsamp: int, subsampling: int,
|
||||
fps: float):
|
||||
self.cfg = cfg
|
||||
|
@ -51,6 +59,13 @@ class Trigger(ABC):
|
|||
# Samples per frame
|
||||
self._real_samp_frame = round(frame_dur * self._wave.smp_s)
|
||||
|
||||
if cfg.post:
|
||||
# Create a post-processing trigger, with narrow nsamp and no subsampling.
|
||||
# This improves speed and precision.
|
||||
self.post = cfg.post(wave, self.POST_PROCESSING_NSAMP, 1, fps)
|
||||
else:
|
||||
self.post = None
|
||||
|
||||
def time2tsamp(self, time: float):
|
||||
return round(time * self._wave.smp_s / self._subsampling)
|
||||
|
||||
|
@ -91,7 +106,6 @@ class PerFrameCache:
|
|||
''')
|
||||
class CorrelationTriggerConfig(ITriggerConfig):
|
||||
# get_trigger
|
||||
use_edge_trigger: bool = True
|
||||
edge_strength: float = 10.0
|
||||
trigger_diameter: float = 0.5
|
||||
|
||||
|
@ -106,6 +120,10 @@ class CorrelationTriggerConfig(ITriggerConfig):
|
|||
# region Legacy Aliases
|
||||
trigger_strength = Alias('edge_strength')
|
||||
falloff_width = Alias('buffer_falloff')
|
||||
|
||||
# Problem: InitVar with default values are (wrongly) accessible on object instances.
|
||||
# use_edge_trigger is False but self.use_edge_trigger is True, wtf?
|
||||
use_edge_trigger: bool = True
|
||||
# endregion
|
||||
|
||||
def __post_init__(self):
|
||||
|
@ -114,6 +132,15 @@ class CorrelationTriggerConfig(ITriggerConfig):
|
|||
# TODO trigger_falloff >= 0
|
||||
self._validate_param('buffer_falloff', 0, np.inf)
|
||||
|
||||
if self.use_edge_trigger:
|
||||
if self.post:
|
||||
warnings.warn(
|
||||
"Ignoring old `CorrelationTriggerConfig.use_edge_trigger` flag, "
|
||||
"overriden by newer `post` flag."
|
||||
)
|
||||
else:
|
||||
self.post = ZeroCrossingTriggerConfig()
|
||||
|
||||
def _validate_param(self, key: str, begin, end):
|
||||
value = getattr(self, key)
|
||||
if not begin <= value <= end:
|
||||
|
@ -124,7 +151,6 @@ class CorrelationTriggerConfig(ITriggerConfig):
|
|||
@register_trigger(CorrelationTriggerConfig)
|
||||
class CorrelationTrigger(Trigger):
|
||||
MIN_AMPLITUDE = 0.01
|
||||
ZERO_CROSSING_SCAN = 256
|
||||
cfg: CorrelationTriggerConfig
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
@ -138,15 +164,6 @@ class CorrelationTrigger(Trigger):
|
|||
# Create correlation buffer (containing a series of old data)
|
||||
self._buffer = np.zeros(self._buffer_nsamp, dtype=FLOAT) # type: np.ndarray[FLOAT]
|
||||
|
||||
# Create zero crossing trigger, for postprocessing results
|
||||
self._zero_trigger = ZeroCrossingTrigger(
|
||||
self._wave,
|
||||
ITriggerConfig(),
|
||||
tsamp=self.ZERO_CROSSING_SCAN,
|
||||
subsampling=1,
|
||||
fps=self._fps
|
||||
)
|
||||
|
||||
# Precompute edge trigger step
|
||||
self._windowed_step = self._calc_step()
|
||||
|
||||
|
@ -207,7 +224,6 @@ class CorrelationTrigger(Trigger):
|
|||
|
||||
def get_trigger(self, index: int, cache: 'PerFrameCache') -> int:
|
||||
N = self._buffer_nsamp
|
||||
use_edge_trigger = self.cfg.use_edge_trigger
|
||||
|
||||
# Get data
|
||||
subsampling = self._subsampling
|
||||
|
@ -267,8 +283,8 @@ class CorrelationTrigger(Trigger):
|
|||
aligned = self._wave.get_around(trigger, self._buffer_nsamp, subsampling)
|
||||
self._update_buffer(aligned, period)
|
||||
|
||||
if use_edge_trigger:
|
||||
return self._zero_trigger.get_trigger(trigger, cache)
|
||||
if self.post:
|
||||
return self.post.get_trigger(trigger, cache)
|
||||
else:
|
||||
return trigger
|
||||
|
||||
|
@ -358,6 +374,12 @@ def lerp(x: np.ndarray, y: np.ndarray, a: float):
|
|||
|
||||
# ZeroCrossingTrigger
|
||||
|
||||
@register_config
|
||||
class ZeroCrossingTriggerConfig(ITriggerConfig):
|
||||
pass
|
||||
|
||||
|
||||
@register_trigger(ZeroCrossingTriggerConfig)
|
||||
class ZeroCrossingTrigger(Trigger):
|
||||
# ZeroCrossingTrigger is only used as a postprocessing trigger.
|
||||
# subsampling is only passed 1, for improved precision.
|
||||
|
|
|
@ -4,7 +4,8 @@ from matplotlib.axes import Axes
|
|||
from matplotlib.figure import Figure
|
||||
|
||||
from ovgenpy import triggers
|
||||
from ovgenpy.triggers import CorrelationTriggerConfig, CorrelationTrigger, PerFrameCache
|
||||
from ovgenpy.triggers import CorrelationTriggerConfig, CorrelationTrigger, \
|
||||
PerFrameCache, ZeroCrossingTriggerConfig
|
||||
from ovgenpy.wave import Wave
|
||||
|
||||
triggers.SHOW_TRIGGER = False
|
||||
|
@ -19,6 +20,16 @@ def cfg(request):
|
|||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope='session', params=[None, ZeroCrossingTriggerConfig()])
|
||||
def post_cfg(request):
|
||||
post = request.param
|
||||
return CorrelationTriggerConfig(
|
||||
use_edge_trigger=False,
|
||||
responsiveness=1,
|
||||
post=post
|
||||
)
|
||||
|
||||
|
||||
# I regret adding the nsamp_frame parameter. It makes unit tests hard.
|
||||
|
||||
FPS = 60
|
||||
|
@ -71,7 +82,6 @@ def test_trigger_subsampling(cfg: CorrelationTriggerConfig):
|
|||
|
||||
for i in range(1, iters):
|
||||
offset = trigger.get_trigger(x0, cache)
|
||||
print(offset)
|
||||
|
||||
# Debugging CorrelationTrigger.get_trigger:
|
||||
# from matplotlib import pyplot as plt
|
||||
|
@ -84,15 +94,40 @@ def test_trigger_subsampling(cfg: CorrelationTriggerConfig):
|
|||
# After truncation, corr[mid+1] is almost identical to corr[mid], for
|
||||
# reasons I don't understand (mid+1 > mid because dithering?).
|
||||
if not cfg.use_edge_trigger:
|
||||
assert (offset - x0) % subsampling == 0
|
||||
assert abs(offset - x0) < 10
|
||||
assert (offset - x0) % subsampling == 0, f'iteration {i}'
|
||||
assert abs(offset - x0) < 10, f'iteration {i}'
|
||||
|
||||
# The edge trigger activates at x0+1=24001. Likely related: it triggers
|
||||
# when moving from <=0 to >0. This is a necessary evil, in order to
|
||||
# recognize 0-to-positive edges while testing tests/impulse24000.wav .
|
||||
|
||||
else:
|
||||
assert abs(offset - x0) <= 2
|
||||
# If assertion fails, remove it.
|
||||
assert (offset - x0) % subsampling != 0, f'iteration {i}'
|
||||
assert abs(offset - x0) <= 2, f'iteration {i}'
|
||||
|
||||
|
||||
def test_post_trigger_subsampling(post_cfg: CorrelationTriggerConfig):
|
||||
cfg = post_cfg
|
||||
|
||||
wave = Wave(None, 'tests/sine440.wav')
|
||||
iters = 5
|
||||
x0 = 24000
|
||||
subsampling = 4
|
||||
trigger = cfg(wave, tsamp=100, subsampling=subsampling, fps=FPS)
|
||||
|
||||
cache = PerFrameCache()
|
||||
for i in range(1, iters):
|
||||
offset = trigger.get_trigger(x0, cache)
|
||||
|
||||
if not cfg.post:
|
||||
assert (offset - x0) % subsampling == 0, f'iteration {i}'
|
||||
assert abs(offset - x0) < 10, f'iteration {i}'
|
||||
|
||||
else:
|
||||
# If assertion fails, remove it.
|
||||
assert (offset - x0) % subsampling != 0, f'iteration {i}'
|
||||
assert abs(offset - x0) <= 2, f'iteration {i}'
|
||||
|
||||
|
||||
def test_trigger_subsampling_edges(cfg: CorrelationTriggerConfig):
|
||||
|
|
Ładowanie…
Reference in New Issue