Fix trigger bugs, fix post trigger width, add tests (#231)

- Fix CorrelationTrigger off-by-one on odd buffer widths
- Fix bug where `post_radius` was *2 to become diameter, then treated as radius
- Fix errors in ZeroCrossingTrigger
- Cleanup test_trigger.py
pull/357/head
nyanpasu64 2019-03-13 02:07:45 -07:00 zatwierdzone przez GitHub
rodzic 39992a222d
commit e2144f3689
4 zmienionych plików z 68 dodań i 26 usunięć

Wyświetl plik

@ -49,13 +49,6 @@ class MainTriggerConfig(
post_trigger: Optional["PostTriggerConfig"] = None
post_radius: Optional[int] = with_units("smp", default=3)
@property
def post_nsamp(self) -> Optional[int]:
if self.post_radius is not None:
return 2 * self.post_radius + 1
else:
return None
def __attrs_post_init__(self):
if self.edge_direction not in [-1, 1]:
raise CorrError(f"{obj_name(self)}.edge_direction must be {{-1, 1}}")
@ -74,9 +67,7 @@ class PostTriggerConfig(_TriggerConfig, KeywordAttrs):
pass
def register_trigger(
config_t: Type[_TriggerConfig]
) -> "Callable[[Type[_Trigger]], Type[_Trigger]]": # my god mypy-strict sucks
def register_trigger(config_t: Type[_TriggerConfig]):
""" @register_trigger(FooTriggerConfig)
def FooTrigger(): ...
"""
@ -136,7 +127,7 @@ class MainTrigger(_Trigger, ABC):
if cfg.post_trigger:
# Create a post-processing trigger, with narrow nsamp and stride=1.
# This improves speed and precision.
self.post = cfg.post_trigger(self._wave, cfg.post_nsamp, 1, self._fps)
self.post = cfg.post_trigger(self._wave, cfg.post_radius, 1, self._fps)
else:
self.post = None
@ -393,6 +384,15 @@ class CorrelationTriggerConfig(MainTriggerConfig, always_dump="pitch_tracking"):
@register_trigger(CorrelationTriggerConfig)
class CorrelationTrigger(MainTrigger):
"""
Assume that if get_trigger(x) == x, then data[[x-1, x]] == [<0, >0].
- edge detectors [halfN = N//2] > 0.
- So wave.get_around(x)[N//2] > 0.
- So wave.get_around(x) = [x - N//2 : ...]
test_trigger() checks that get_around() works properly, for even/odd N.
"""
cfg: CorrelationTriggerConfig
@property
@ -793,7 +793,7 @@ class ZeroCrossingTrigger(PostTrigger):
def get_trigger(self, index: int, cache: "PerFrameCache") -> int:
# 'cache' is unused.
tsamp = self._tsamp
radius = self._tsamp
if not 0 <= index < self._wave.nsamp:
return index
@ -809,7 +809,7 @@ class ZeroCrossingTrigger(PostTrigger):
else: # self._wave[sample] == 0
return index + 1
data = self._wave[index : index + (direction * tsamp) : direction]
data = self._wave[index : index + direction * (radius + 1) : direction]
# TODO remove unnecessary complexity, since diameter is probably under 10.
intercepts = find(data, test)
try:
@ -817,7 +817,7 @@ class ZeroCrossingTrigger(PostTrigger):
return index + (delta * direction) + int(value <= 0)
except StopIteration: # No zero-intercepts
return index + (direction * tsamp)
return index + (direction * radius)
# noinspection PyUnreachableCode
"""

Wyświetl plik

@ -203,8 +203,8 @@ class Wave:
Copies self.data[...] """
distance = return_nsamp * stride
end = sample + distance // 2
begin = end - distance
begin = sample - distance // 2
end = begin + distance
return self._get(begin, end, stride)
def get_s(self) -> float:

Plik binarny nie jest wyświetlany.

Wyświetl plik

@ -39,19 +39,28 @@ def cfg(trigger_diameter, pitch_tracking):
)
# I regret adding the nsamp_frame parameter. It makes unit tests hard.
FPS = 60
is_odd = parametrize("is_odd", [False, True])
def test_trigger(cfg: CorrelationTriggerConfig):
wave = Wave("tests/impulse24000.wav")
# CorrelationTrigger overall tests
@is_odd
@parametrize("post_trigger", [None, ZeroCrossingTriggerConfig()])
def test_trigger(cfg: CorrelationTriggerConfig, is_odd: bool, post_trigger):
"""Ensures that trigger can locate
the first positive sample of a -+ step exactly,
without off-by-1 errors."""
wave = Wave("tests/step2400.wav")
cfg = attr.evolve(cfg, post_trigger=post_trigger)
iters = 5
plot = False
x0 = 24000
x = x0 - 500
trigger: CorrelationTrigger = cfg(wave, 4000, stride=1, fps=FPS)
x0 = 2400
x = x0 - 50
trigger: CorrelationTrigger = cfg(wave, 400 + int(is_odd), stride=1, fps=FPS)
if plot:
BIG = 0.95
@ -78,6 +87,10 @@ def test_trigger(cfg: CorrelationTriggerConfig):
@parametrize("post_trigger", [None, ZeroCrossingTriggerConfig()])
def test_post_stride(post_trigger):
"""
Test that stride is respected when post_trigger is disabled,
and ignored when post_trigger is enabled.
"""
cfg = cfg_template(post_trigger=post_trigger)
wave = Wave("tests/sine440.wav")
@ -127,7 +140,9 @@ def test_trigger_direction(post_trigger, double_negate):
assert trigger.get_trigger(index + dx, cache) == index
def test_trigger_stride_edges(cfg: CorrelationTriggerConfig):
def test_trigger_out_of_bounds(cfg: CorrelationTriggerConfig):
"""Ensure out-of-bounds triggering with stride does not crash.
(why does stride matter? IDK.)"""
wave = Wave("tests/sine440.wav")
# period = 48000 / 440 = 109.(09)*
@ -141,7 +156,7 @@ def test_trigger_stride_edges(cfg: CorrelationTriggerConfig):
trigger.get_trigger(50000, PerFrameCache())
def test_trigger_should_recalc_window():
def test_when_does_trigger_recalc_window():
cfg = cfg_template(recalc_semitones=1.0)
wave = Wave("tests/sine440.wav")
trigger: CorrelationTrigger = cfg(wave, tsamp=1000, stride=1, fps=FPS)
@ -164,7 +179,34 @@ def test_trigger_should_recalc_window():
assert trigger._is_window_invalid(x), x
# Test pitch-invariant triggering using spectrum
# Test post triggering by itself
def test_post_trigger_radius():
"""
Ensure ZeroCrossingTrigger has no off-by-1 errors when locating edges,
and slides at a fixed rate if no edge is found.
"""
wave = Wave("tests/step2400.wav")
center = 2400
radius = 5
cfg = ZeroCrossingTriggerConfig()
post = cfg(wave, radius, 1, FPS)
cache = PerFrameCache(mean=0)
for offset in range(-radius, radius + 1):
assert post.get_trigger(center + offset, cache) == center, offset
for offset in [radius + 1, radius + 2, 100]:
assert post.get_trigger(center - offset, cache) == center - offset + radius
assert post.get_trigger(center + offset, cache) == center + offset - radius
# Test pitch-tracking (spectrum)
def test_correlate_offset():
"""
Catches bug where writing N instead of Ncorr