diff --git a/corrscope/triggers.py b/corrscope/triggers.py index f63eaf5..08d5e2a 100644 --- a/corrscope/triggers.py +++ b/corrscope/triggers.py @@ -49,13 +49,6 @@ class MainTriggerConfig( post_trigger: Optional["PostTriggerConfig"] = None post_radius: Optional[int] = with_units("smp", default=3) - @property - def post_nsamp(self) -> Optional[int]: - if self.post_radius is not None: - return 2 * self.post_radius + 1 - else: - return None - def __attrs_post_init__(self): if self.edge_direction not in [-1, 1]: raise CorrError(f"{obj_name(self)}.edge_direction must be {{-1, 1}}") @@ -74,9 +67,7 @@ class PostTriggerConfig(_TriggerConfig, KeywordAttrs): pass -def register_trigger( - config_t: Type[_TriggerConfig] -) -> "Callable[[Type[_Trigger]], Type[_Trigger]]": # my god mypy-strict sucks +def register_trigger(config_t: Type[_TriggerConfig]): """ @register_trigger(FooTriggerConfig) def FooTrigger(): ... """ @@ -136,7 +127,7 @@ class MainTrigger(_Trigger, ABC): if cfg.post_trigger: # Create a post-processing trigger, with narrow nsamp and stride=1. # This improves speed and precision. - self.post = cfg.post_trigger(self._wave, cfg.post_nsamp, 1, self._fps) + self.post = cfg.post_trigger(self._wave, cfg.post_radius, 1, self._fps) else: self.post = None @@ -393,6 +384,15 @@ class CorrelationTriggerConfig(MainTriggerConfig, always_dump="pitch_tracking"): @register_trigger(CorrelationTriggerConfig) class CorrelationTrigger(MainTrigger): + """ + Assume that if get_trigger(x) == x, then data[[x-1, x]] == [<0, >0]. + - edge detectors [halfN = N//2] > 0. + - So wave.get_around(x)[N//2] > 0. + - So wave.get_around(x) = [x - N//2 : ...] + + test_trigger() checks that get_around() works properly, for even/odd N. + """ + cfg: CorrelationTriggerConfig @property @@ -793,7 +793,7 @@ class ZeroCrossingTrigger(PostTrigger): def get_trigger(self, index: int, cache: "PerFrameCache") -> int: # 'cache' is unused. - tsamp = self._tsamp + radius = self._tsamp if not 0 <= index < self._wave.nsamp: return index @@ -809,7 +809,7 @@ class ZeroCrossingTrigger(PostTrigger): else: # self._wave[sample] == 0 return index + 1 - data = self._wave[index : index + (direction * tsamp) : direction] + data = self._wave[index : index + direction * (radius + 1) : direction] # TODO remove unnecessary complexity, since diameter is probably under 10. intercepts = find(data, test) try: @@ -817,7 +817,7 @@ class ZeroCrossingTrigger(PostTrigger): return index + (delta * direction) + int(value <= 0) except StopIteration: # No zero-intercepts - return index + (direction * tsamp) + return index + (direction * radius) # noinspection PyUnreachableCode """ diff --git a/corrscope/wave.py b/corrscope/wave.py index 4da1014..63b01f1 100644 --- a/corrscope/wave.py +++ b/corrscope/wave.py @@ -203,8 +203,8 @@ class Wave: Copies self.data[...] """ distance = return_nsamp * stride - end = sample + distance // 2 - begin = end - distance + begin = sample - distance // 2 + end = begin + distance return self._get(begin, end, stride) def get_s(self) -> float: diff --git a/tests/impulse24000.wav b/tests/impulse24000.wav deleted file mode 100644 index 3fa2fbd..0000000 Binary files a/tests/impulse24000.wav and /dev/null differ diff --git a/tests/test_trigger.py b/tests/test_trigger.py index 8a07218..9e43b99 100644 --- a/tests/test_trigger.py +++ b/tests/test_trigger.py @@ -39,19 +39,28 @@ def cfg(trigger_diameter, pitch_tracking): ) -# I regret adding the nsamp_frame parameter. It makes unit tests hard. - FPS = 60 +is_odd = parametrize("is_odd", [False, True]) -def test_trigger(cfg: CorrelationTriggerConfig): - wave = Wave("tests/impulse24000.wav") + +# CorrelationTrigger overall tests + + +@is_odd +@parametrize("post_trigger", [None, ZeroCrossingTriggerConfig()]) +def test_trigger(cfg: CorrelationTriggerConfig, is_odd: bool, post_trigger): + """Ensures that trigger can locate + the first positive sample of a -+ step exactly, + without off-by-1 errors.""" + wave = Wave("tests/step2400.wav") + cfg = attr.evolve(cfg, post_trigger=post_trigger) iters = 5 plot = False - x0 = 24000 - x = x0 - 500 - trigger: CorrelationTrigger = cfg(wave, 4000, stride=1, fps=FPS) + x0 = 2400 + x = x0 - 50 + trigger: CorrelationTrigger = cfg(wave, 400 + int(is_odd), stride=1, fps=FPS) if plot: BIG = 0.95 @@ -78,6 +87,10 @@ def test_trigger(cfg: CorrelationTriggerConfig): @parametrize("post_trigger", [None, ZeroCrossingTriggerConfig()]) def test_post_stride(post_trigger): + """ + Test that stride is respected when post_trigger is disabled, + and ignored when post_trigger is enabled. + """ cfg = cfg_template(post_trigger=post_trigger) wave = Wave("tests/sine440.wav") @@ -127,7 +140,9 @@ def test_trigger_direction(post_trigger, double_negate): assert trigger.get_trigger(index + dx, cache) == index -def test_trigger_stride_edges(cfg: CorrelationTriggerConfig): +def test_trigger_out_of_bounds(cfg: CorrelationTriggerConfig): + """Ensure out-of-bounds triggering with stride does not crash. + (why does stride matter? IDK.)""" wave = Wave("tests/sine440.wav") # period = 48000 / 440 = 109.(09)* @@ -141,7 +156,7 @@ def test_trigger_stride_edges(cfg: CorrelationTriggerConfig): trigger.get_trigger(50000, PerFrameCache()) -def test_trigger_should_recalc_window(): +def test_when_does_trigger_recalc_window(): cfg = cfg_template(recalc_semitones=1.0) wave = Wave("tests/sine440.wav") trigger: CorrelationTrigger = cfg(wave, tsamp=1000, stride=1, fps=FPS) @@ -164,7 +179,34 @@ def test_trigger_should_recalc_window(): assert trigger._is_window_invalid(x), x -# Test pitch-invariant triggering using spectrum +# Test post triggering by itself + + +def test_post_trigger_radius(): + """ + Ensure ZeroCrossingTrigger has no off-by-1 errors when locating edges, + and slides at a fixed rate if no edge is found. + """ + wave = Wave("tests/step2400.wav") + center = 2400 + radius = 5 + + cfg = ZeroCrossingTriggerConfig() + post = cfg(wave, radius, 1, FPS) + + cache = PerFrameCache(mean=0) + + for offset in range(-radius, radius + 1): + assert post.get_trigger(center + offset, cache) == center, offset + + for offset in [radius + 1, radius + 2, 100]: + assert post.get_trigger(center - offset, cache) == center - offset + radius + assert post.get_trigger(center + offset, cache) == center + offset - radius + + +# Test pitch-tracking (spectrum) + + def test_correlate_offset(): """ Catches bug where writing N instead of Ncorr