diff --git a/corrscope/corrscope.py b/corrscope/corrscope.py index e917e67..94ce1dd 100644 --- a/corrscope/corrscope.py +++ b/corrscope/corrscope.py @@ -1,12 +1,12 @@ # -*- coding: utf-8 -*- import time from contextlib import ExitStack, contextmanager -from enum import unique, Enum +from enum import unique from fractions import Fraction from pathlib import Path from types import SimpleNamespace from typing import Iterator -from typing import Optional, List, Union, Callable, cast +from typing import Optional, List, Callable, cast import attr @@ -16,7 +16,6 @@ from corrscope.config import KeywordAttrs, DumpEnumAsStr, CorrError, with_units from corrscope.layout import LayoutConfig from corrscope.renderer import MatplotlibRenderer, RendererConfig, Renderer from corrscope.triggers import ( - MainTriggerConfig, CorrelationTriggerConfig, PerFrameCache, CorrelationTrigger, @@ -118,9 +117,8 @@ def default_config(**kwargs) -> Config: edge_strength=2, responsiveness=0.5, buffer_falloff=0.5, - pitch_tracking=SpectrumConfig() - # Removed due to speed hit. - # post=LocalPostTriggerConfig(strength=0.1), + pitch_tracking=SpectrumConfig(), + # post_trigger=ZeroCrossingTriggerConfig(), ), channels=[], layout=LayoutConfig(orientation="v", ncols=1), diff --git a/corrscope/triggers.py b/corrscope/triggers.py index e94208e..c1ae788 100644 --- a/corrscope/triggers.py +++ b/corrscope/triggers.py @@ -44,8 +44,7 @@ class MainTriggerConfig(_TriggerConfig, KeywordAttrs, always_dump="edge_directio edge_direction: int = 1 # Optional trigger for postprocessing - # TODO rename to post_trigger - post: Optional["PostTriggerConfig"] = None + post_trigger: Optional["PostTriggerConfig"] = None post_radius: Optional[int] = 3 @property @@ -59,17 +58,17 @@ class MainTriggerConfig(_TriggerConfig, KeywordAttrs, always_dump="edge_directio if self.edge_direction not in [-1, 1]: raise CorrError(f"{obj_name(self)}.edge_direction must be {{-1, 1}}") - if self.post: - self.post.parent = self + if self.post_trigger: + self.post_trigger.parent = self if self.post_radius is None: name = obj_name(self) raise CorrError( - f"Cannot supply {name}.post without supplying {name}.post_radius" + f"Cannot supply {name}.post_trigger without supplying {name}.post_radius" ) class PostTriggerConfig(_TriggerConfig, KeywordAttrs): - parent: MainTriggerConfig = attr.ib(init=False) + parent: MainTriggerConfig = attr.ib(init=False) # TODO Unused pass @@ -132,10 +131,10 @@ class MainTrigger(_Trigger, ABC): self._wave.amplification *= self.cfg.edge_direction cfg = self.cfg - if cfg.post: + if cfg.post_trigger: # Create a post-processing trigger, with narrow nsamp and stride=1. # This improves speed and precision. - self.post = cfg.post(self._wave, cfg.post_nsamp, 1, self._fps) + self.post = cfg.post_trigger(self._wave, cfg.post_nsamp, 1, self._fps) else: self.post = None diff --git a/tests/test_trigger.py b/tests/test_trigger.py index 91c0707..8a07218 100644 --- a/tests/test_trigger.py +++ b/tests/test_trigger.py @@ -76,9 +76,9 @@ def test_trigger(cfg: CorrelationTriggerConfig): plt.show() -@parametrize("post", [None, ZeroCrossingTriggerConfig()]) -def test_post_stride(post): - cfg = cfg_template(post=post) +@parametrize("post_trigger", [None, ZeroCrossingTriggerConfig()]) +def test_post_stride(post_trigger): + cfg = cfg_template(post_trigger=post_trigger) wave = Wave("tests/sine440.wav") iters = 5 @@ -90,7 +90,7 @@ def test_post_stride(post): for i in range(1, iters): offset = trigger.get_trigger(x0, cache) - if not cfg.post: + if not cfg.post_trigger: assert (offset - x0) % stride == 0, f"iteration {i}" assert abs(offset - x0) < 10, f"iteration {i}" @@ -100,9 +100,9 @@ def test_post_stride(post): assert abs(offset - x0) <= 2, f"iteration {i}" -@parametrize("post", [None, ZeroCrossingTriggerConfig()]) +@parametrize("post_trigger", [None, ZeroCrossingTriggerConfig()]) @parametrize("double_negate", [False, True]) -def test_trigger_direction(post, double_negate): +def test_trigger_direction(post_trigger, double_negate): """ Right now, MainTrigger is responsible for negating wave.amplification if edge_direction == -1. @@ -114,9 +114,9 @@ def test_trigger_direction(post, double_negate): if double_negate: wave.amplification = -1 - cfg = cfg_template(post=post, edge_direction=-1) + cfg = cfg_template(post_trigger=post_trigger, edge_direction=-1) else: - cfg = cfg_template(post=post) + cfg = cfg_template(post_trigger=post_trigger) trigger = cfg(wave, 100, 1, FPS) cfg.edge_direction = None