kopia lustrzana https://github.com/corrscope/corrscope
Fix trigger bugs, fix post trigger width, add tests (#231)
- Fix CorrelationTrigger off-by-one on odd buffer widths - Fix bug where `post_radius` was *2 to become diameter, then treated as radius - Fix errors in ZeroCrossingTrigger - Cleanup test_trigger.pypull/357/head
rodzic
39992a222d
commit
e2144f3689
|
@ -49,13 +49,6 @@ class MainTriggerConfig(
|
|||
post_trigger: Optional["PostTriggerConfig"] = None
|
||||
post_radius: Optional[int] = with_units("smp", default=3)
|
||||
|
||||
@property
|
||||
def post_nsamp(self) -> Optional[int]:
|
||||
if self.post_radius is not None:
|
||||
return 2 * self.post_radius + 1
|
||||
else:
|
||||
return None
|
||||
|
||||
def __attrs_post_init__(self):
|
||||
if self.edge_direction not in [-1, 1]:
|
||||
raise CorrError(f"{obj_name(self)}.edge_direction must be {{-1, 1}}")
|
||||
|
@ -74,9 +67,7 @@ class PostTriggerConfig(_TriggerConfig, KeywordAttrs):
|
|||
pass
|
||||
|
||||
|
||||
def register_trigger(
|
||||
config_t: Type[_TriggerConfig]
|
||||
) -> "Callable[[Type[_Trigger]], Type[_Trigger]]": # my god mypy-strict sucks
|
||||
def register_trigger(config_t: Type[_TriggerConfig]):
|
||||
""" @register_trigger(FooTriggerConfig)
|
||||
def FooTrigger(): ...
|
||||
"""
|
||||
|
@ -136,7 +127,7 @@ class MainTrigger(_Trigger, ABC):
|
|||
if cfg.post_trigger:
|
||||
# Create a post-processing trigger, with narrow nsamp and stride=1.
|
||||
# This improves speed and precision.
|
||||
self.post = cfg.post_trigger(self._wave, cfg.post_nsamp, 1, self._fps)
|
||||
self.post = cfg.post_trigger(self._wave, cfg.post_radius, 1, self._fps)
|
||||
else:
|
||||
self.post = None
|
||||
|
||||
|
@ -393,6 +384,15 @@ class CorrelationTriggerConfig(MainTriggerConfig, always_dump="pitch_tracking"):
|
|||
|
||||
@register_trigger(CorrelationTriggerConfig)
|
||||
class CorrelationTrigger(MainTrigger):
|
||||
"""
|
||||
Assume that if get_trigger(x) == x, then data[[x-1, x]] == [<0, >0].
|
||||
- edge detectors [halfN = N//2] > 0.
|
||||
- So wave.get_around(x)[N//2] > 0.
|
||||
- So wave.get_around(x) = [x - N//2 : ...]
|
||||
|
||||
test_trigger() checks that get_around() works properly, for even/odd N.
|
||||
"""
|
||||
|
||||
cfg: CorrelationTriggerConfig
|
||||
|
||||
@property
|
||||
|
@ -793,7 +793,7 @@ class ZeroCrossingTrigger(PostTrigger):
|
|||
|
||||
def get_trigger(self, index: int, cache: "PerFrameCache") -> int:
|
||||
# 'cache' is unused.
|
||||
tsamp = self._tsamp
|
||||
radius = self._tsamp
|
||||
|
||||
if not 0 <= index < self._wave.nsamp:
|
||||
return index
|
||||
|
@ -809,7 +809,7 @@ class ZeroCrossingTrigger(PostTrigger):
|
|||
else: # self._wave[sample] == 0
|
||||
return index + 1
|
||||
|
||||
data = self._wave[index : index + (direction * tsamp) : direction]
|
||||
data = self._wave[index : index + direction * (radius + 1) : direction]
|
||||
# TODO remove unnecessary complexity, since diameter is probably under 10.
|
||||
intercepts = find(data, test)
|
||||
try:
|
||||
|
@ -817,7 +817,7 @@ class ZeroCrossingTrigger(PostTrigger):
|
|||
return index + (delta * direction) + int(value <= 0)
|
||||
|
||||
except StopIteration: # No zero-intercepts
|
||||
return index + (direction * tsamp)
|
||||
return index + (direction * radius)
|
||||
|
||||
# noinspection PyUnreachableCode
|
||||
"""
|
||||
|
|
|
@ -203,8 +203,8 @@ class Wave:
|
|||
|
||||
Copies self.data[...] """
|
||||
distance = return_nsamp * stride
|
||||
end = sample + distance // 2
|
||||
begin = end - distance
|
||||
begin = sample - distance // 2
|
||||
end = begin + distance
|
||||
return self._get(begin, end, stride)
|
||||
|
||||
def get_s(self) -> float:
|
||||
|
|
Plik binarny nie jest wyświetlany.
|
@ -39,19 +39,28 @@ def cfg(trigger_diameter, pitch_tracking):
|
|||
)
|
||||
|
||||
|
||||
# I regret adding the nsamp_frame parameter. It makes unit tests hard.
|
||||
|
||||
FPS = 60
|
||||
|
||||
is_odd = parametrize("is_odd", [False, True])
|
||||
|
||||
def test_trigger(cfg: CorrelationTriggerConfig):
|
||||
wave = Wave("tests/impulse24000.wav")
|
||||
|
||||
# CorrelationTrigger overall tests
|
||||
|
||||
|
||||
@is_odd
|
||||
@parametrize("post_trigger", [None, ZeroCrossingTriggerConfig()])
|
||||
def test_trigger(cfg: CorrelationTriggerConfig, is_odd: bool, post_trigger):
|
||||
"""Ensures that trigger can locate
|
||||
the first positive sample of a -+ step exactly,
|
||||
without off-by-1 errors."""
|
||||
wave = Wave("tests/step2400.wav")
|
||||
cfg = attr.evolve(cfg, post_trigger=post_trigger)
|
||||
|
||||
iters = 5
|
||||
plot = False
|
||||
x0 = 24000
|
||||
x = x0 - 500
|
||||
trigger: CorrelationTrigger = cfg(wave, 4000, stride=1, fps=FPS)
|
||||
x0 = 2400
|
||||
x = x0 - 50
|
||||
trigger: CorrelationTrigger = cfg(wave, 400 + int(is_odd), stride=1, fps=FPS)
|
||||
|
||||
if plot:
|
||||
BIG = 0.95
|
||||
|
@ -78,6 +87,10 @@ def test_trigger(cfg: CorrelationTriggerConfig):
|
|||
|
||||
@parametrize("post_trigger", [None, ZeroCrossingTriggerConfig()])
|
||||
def test_post_stride(post_trigger):
|
||||
"""
|
||||
Test that stride is respected when post_trigger is disabled,
|
||||
and ignored when post_trigger is enabled.
|
||||
"""
|
||||
cfg = cfg_template(post_trigger=post_trigger)
|
||||
|
||||
wave = Wave("tests/sine440.wav")
|
||||
|
@ -127,7 +140,9 @@ def test_trigger_direction(post_trigger, double_negate):
|
|||
assert trigger.get_trigger(index + dx, cache) == index
|
||||
|
||||
|
||||
def test_trigger_stride_edges(cfg: CorrelationTriggerConfig):
|
||||
def test_trigger_out_of_bounds(cfg: CorrelationTriggerConfig):
|
||||
"""Ensure out-of-bounds triggering with stride does not crash.
|
||||
(why does stride matter? IDK.)"""
|
||||
wave = Wave("tests/sine440.wav")
|
||||
# period = 48000 / 440 = 109.(09)*
|
||||
|
||||
|
@ -141,7 +156,7 @@ def test_trigger_stride_edges(cfg: CorrelationTriggerConfig):
|
|||
trigger.get_trigger(50000, PerFrameCache())
|
||||
|
||||
|
||||
def test_trigger_should_recalc_window():
|
||||
def test_when_does_trigger_recalc_window():
|
||||
cfg = cfg_template(recalc_semitones=1.0)
|
||||
wave = Wave("tests/sine440.wav")
|
||||
trigger: CorrelationTrigger = cfg(wave, tsamp=1000, stride=1, fps=FPS)
|
||||
|
@ -164,7 +179,34 @@ def test_trigger_should_recalc_window():
|
|||
assert trigger._is_window_invalid(x), x
|
||||
|
||||
|
||||
# Test pitch-invariant triggering using spectrum
|
||||
# Test post triggering by itself
|
||||
|
||||
|
||||
def test_post_trigger_radius():
|
||||
"""
|
||||
Ensure ZeroCrossingTrigger has no off-by-1 errors when locating edges,
|
||||
and slides at a fixed rate if no edge is found.
|
||||
"""
|
||||
wave = Wave("tests/step2400.wav")
|
||||
center = 2400
|
||||
radius = 5
|
||||
|
||||
cfg = ZeroCrossingTriggerConfig()
|
||||
post = cfg(wave, radius, 1, FPS)
|
||||
|
||||
cache = PerFrameCache(mean=0)
|
||||
|
||||
for offset in range(-radius, radius + 1):
|
||||
assert post.get_trigger(center + offset, cache) == center, offset
|
||||
|
||||
for offset in [radius + 1, radius + 2, 100]:
|
||||
assert post.get_trigger(center - offset, cache) == center - offset + radius
|
||||
assert post.get_trigger(center + offset, cache) == center + offset - radius
|
||||
|
||||
|
||||
# Test pitch-tracking (spectrum)
|
||||
|
||||
|
||||
def test_correlate_offset():
|
||||
"""
|
||||
Catches bug where writing N instead of Ncorr
|
||||
|
|
Ładowanie…
Reference in New Issue