2018-12-06 20:27:09 +00:00
|
|
|
import attr
|
2018-07-14 10:36:49 +00:00
|
|
|
import matplotlib.pyplot as plt
|
2019-02-26 05:49:42 +00:00
|
|
|
import numpy as np
|
2018-07-15 13:04:50 +00:00
|
|
|
import pytest
|
2018-07-14 10:36:49 +00:00
|
|
|
from matplotlib.axes import Axes
|
|
|
|
from matplotlib.figure import Figure
|
2019-02-26 05:49:42 +00:00
|
|
|
from pytest_cases import pytest_fixture_plus
|
2018-07-14 10:36:49 +00:00
|
|
|
|
2018-12-20 10:31:55 +00:00
|
|
|
from corrscope import triggers
|
2019-01-03 08:57:30 +00:00
|
|
|
from corrscope.triggers import (
|
|
|
|
CorrelationTriggerConfig,
|
|
|
|
CorrelationTrigger,
|
|
|
|
PerFrameCache,
|
|
|
|
ZeroCrossingTriggerConfig,
|
2019-02-26 05:49:42 +00:00
|
|
|
SpectrumConfig,
|
2019-04-13 13:09:43 +00:00
|
|
|
correlate_data,
|
|
|
|
correlate_spectrum,
|
2019-01-03 08:57:30 +00:00
|
|
|
)
|
2018-12-20 10:31:55 +00:00
|
|
|
from corrscope.wave import Wave
|
2018-07-14 10:36:49 +00:00
|
|
|
|
2019-03-07 05:38:52 +00:00
|
|
|
# Short alias so test decorators below can be written as @parametrize(...).
parametrize = pytest.mark.parametrize

# Module-level switch on corrscope.triggers, turned off for the test run.
# NOTE(review): presumably disables debug plotting of triggers -- confirm
# against corrscope.triggers.
triggers.SHOW_TRIGGER = False
|
|
|
|
|
|
|
|
|
2018-12-06 20:27:09 +00:00
|
|
|
def cfg_template(**kwargs) -> CorrelationTriggerConfig:
    """Build the baseline CorrelationTriggerConfig shared by these tests.

    Not identical to the default_config() template. Any keyword arguments
    are applied on top of the baseline through attr.evolve().
    """
    baseline = CorrelationTriggerConfig(
        edge_strength=2,
        responsiveness=1,
        buffer_falloff=0.5,
    )
    overridden = attr.evolve(baseline, **kwargs)
    return overridden
|
|
|
|
|
|
|
|
|
2019-02-26 05:49:42 +00:00
|
|
|
@pytest_fixture_plus
@parametrize("trigger_diameter", [None, 0.5])
@parametrize("pitch_tracking", [None, SpectrumConfig()])
@parametrize("slope_strength", [0, 100])
def cfg(trigger_diameter, pitch_tracking, slope_strength):
    """Fixture returning one config per combination of the parametrized fields.

    slope_width is held at 0.14 for every combination; everything else comes
    from cfg_template().
    """
    overrides = dict(
        trigger_diameter=trigger_diameter,
        pitch_tracking=pitch_tracking,
        slope_strength=slope_strength,
        slope_width=0.14,
    )
    return cfg_template(**overrides)
|
2018-07-15 13:04:50 +00:00
|
|
|
|
|
|
|
|
2018-08-26 01:51:01 +00:00
|
|
|
# Value passed as the `fps` argument whenever a trigger is built in this file.
FPS = 60

# Reusable decorator that runs a test with is_odd=False and is_odd=True.
is_odd = parametrize("is_odd", [False, True])
|
|
|
|
|
|
|
|
|
|
|
|
# CorrelationTrigger overall tests
|
2019-01-03 08:57:30 +00:00
|
|
|
|
2019-03-13 09:07:45 +00:00
|
|
|
|
|
|
|
@is_odd
@parametrize("post_trigger", [None, ZeroCrossingTriggerConfig()])
def test_trigger(cfg: CorrelationTriggerConfig, is_odd: bool, post_trigger):
    """Check that the trigger finds the first positive sample of a -+ step
    exactly, with no off-by-1 error.

    The trigger is queried 50 samples before the step at 2400, using a
    sample count of 400 or 401 depending on is_odd.

    See CorrelationTrigger and Wave.get_around() docstrings.
    """
    wave = Wave("tests/step2400.wav")
    trigger_cfg = attr.evolve(cfg, post_trigger=post_trigger)

    # Flip show_plot to True by hand to inspect the buffer after each pass.
    show_plot = False
    passes = 5
    edge_index = 2400
    query_index = edge_index - 50

    tsamp = 400 + int(is_odd)
    trigger: CorrelationTrigger = trigger_cfg(wave, tsamp, stride=1, fps=FPS)

    if not show_plot:
        # Placeholder iterable so the loop below runs the same number of times.
        axes = range(passes)
    else:
        BIG = 0.95
        SMALL = 0.05
        margins = {"top": BIG, "right": BIG, "bottom": SMALL, "left": SMALL}
        fig, axes = plt.subplots(passes, gridspec_kw=margins)  # type: Figure, Axes
        fig.tight_layout()

    for step, ax in enumerate(axes):
        # The first pass makes no trigger call; it only plots (when enabled).
        if step > 0:
            offset = trigger.get_trigger(query_index, PerFrameCache())
            assert offset == edge_index, offset

        if show_plot:
            ax.plot(trigger._buffer, label=str(step))
            ax.grid()

    if show_plot:
        plt.show()
|
2018-07-29 09:07:00 +00:00
|
|
|
|
|
|
|
|
2019-03-09 05:03:32 +00:00
|
|
|
@parametrize("post_trigger", [None, ZeroCrossingTriggerConfig()])
def test_post_stride(post_trigger):
    """
    Stride must be respected while post_trigger is disabled,
    and ignored once post_trigger is enabled.
    """
    cfg = cfg_template(post_trigger=post_trigger)
    wave = Wave("tests/sine440.wav")

    stride = 4
    x0 = 24000
    iters = 5
    trigger = cfg(wave, tsamp=100, stride=stride, fps=FPS)

    # A single cache is shared by every call in the loop.
    cache = PerFrameCache()
    for i in range(1, iters):
        delta = trigger.get_trigger(x0, cache) - x0
        where = f"iteration {i}"

        if cfg.post_trigger:
            # If assertion fails, remove it.
            assert delta % stride != 0, where
            assert abs(delta) <= 2, where
        else:
            assert delta % stride == 0, where
            assert abs(delta) < 10, where
|
2018-10-29 02:12:02 +00:00
|
|
|
|
|
|
|
|
2019-03-09 05:03:32 +00:00
|
|
|
@parametrize("post_trigger", [None, ZeroCrossingTriggerConfig()])
@parametrize("double_negate", [False, True])
def test_trigger_direction(post_trigger, double_negate):
    """
    Right now, MainTrigger is responsible for negating wave.amplification
    if edge_direction == -1.
    And triggers should not actually access edge_direction.
    """
    index = 2400
    wave = Wave("tests/step2400.wav")

    overrides = {"post_trigger": post_trigger}
    if double_negate:
        # Negate the wave and request a falling edge; the two should cancel.
        wave.amplification = -1
        overrides["edge_direction"] = -1
    cfg = cfg_template(**overrides)

    trigger = cfg(wave, 100, 1, FPS)
    # Clear the field after construction so any later read of it by the
    # trigger no longer sees a valid direction.
    cfg.edge_direction = None
    assert trigger._wave.amplification == 1

    cache = PerFrameCache()
    for dx in (-10, 10, 0):
        found = trigger.get_trigger(index + dx, cache)
        assert found == index
|
2018-07-29 09:07:00 +00:00
|
|
|
|
|
|
|
|
2019-03-13 09:07:45 +00:00
|
|
|
def test_trigger_out_of_bounds(cfg: CorrelationTriggerConfig):
    """Out-of-bounds triggering with stride must not crash.

    (why does stride matter? IDK.)"""
    # period = 48000 / 440 = 109.(09)*
    wave = Wave("tests/sine440.wav")

    # real window_samp = window_samp*stride
    # period = 109
    stride = 4
    trigger = cfg(wave, tsamp=100, stride=stride, fps=FPS)

    # Each position gets a fresh cache; only the absence of exceptions matters.
    for position in (0, -1000, 50000):
        trigger.get_trigger(position, PerFrameCache())
|
2018-07-29 09:07:00 +00:00
|
|
|
|
|
|
|
|
2019-03-13 09:07:45 +00:00
|
|
|
def test_when_does_trigger_recalc_window():
    """Pin down which periods _is_window_invalid() flags, given _prev_period.

    With recalc_semitones=1.0:
    - on a freshly built trigger, 0, 1 and 1000 are all invalid;
    - with _prev_period = 100, 99 and 101 are valid; 0, 80 and 120 are not;
    - with _prev_period = 0, only 0 is valid; 1 and 100 are not.
    """
    cfg = cfg_template(recalc_semitones=1.0)
    wave = Wave("tests/sine440.wav")
    trigger: CorrelationTrigger = cfg(wave, tsamp=1000, stride=1, fps=FPS)

    def expect_invalid(*periods):
        for period in periods:
            assert trigger._is_window_invalid(period), period

    def expect_valid(*periods):
        for period in periods:
            assert not trigger._is_window_invalid(period), period

    expect_invalid(0, 1, 1000)

    trigger._prev_period = 100
    expect_valid(99, 101)
    expect_invalid(0, 80, 120)

    trigger._prev_period = 0
    expect_valid(0)
    expect_invalid(1, 100)
|
|
|
|
|
|
|
|
|
2019-03-13 09:07:45 +00:00
|
|
|
# Test post triggering by itself
|
|
|
|
|
|
|
|
|
|
|
|
def test_post_trigger_radius():
    """
    ZeroCrossingTrigger must locate edges without off-by-1 errors,
    and slide at a fixed rate when no edge is found.
    """
    center = 2400
    radius = 5
    wave = Wave("tests/step2400.wav")

    post = ZeroCrossingTriggerConfig()(wave, radius, 1, FPS)
    cache = PerFrameCache(mean=0)

    # Starting anywhere within `radius` of the edge must land on the edge.
    for offset in range(-radius, radius + 1):
        start = center + offset
        assert post.get_trigger(start, cache) == center, offset

    # Starting farther away, the result moves exactly `radius` toward the edge.
    for offset in (radius + 1, radius + 2, 100):
        below = center - offset
        above = center + offset
        assert post.get_trigger(below, cache) == below + radius
        assert post.get_trigger(above, cache) == above - radius
|
|
|
|
|
|
|
|
|
|
|
|
# Test pitch-tracking (spectrum)
|
|
|
|
|
|
|
|
|
2019-04-13 13:09:43 +00:00
|
|
|
@parametrize("correlate", [correlate_data, correlate_spectrum])
def test_correlate_offset(correlate):
    """
    Catches bug where writing N instead of Ncorr
    prevented function from returning positive numbers.
    """
    # correlate_spectrum is compared with a 50% relative tolerance;
    # correlate_data must match exactly.
    loose = correlate == correlate_spectrum

    def approx(expected):
        if loose:
            return pytest.approx(expected, rel=0.5)
        return expected

    np.random.seed(31337)

    # Ensure autocorrelation on random data returns peak at 0.
    N = 100
    spectrum = np.random.random(N)
    auto = correlate(spectrum, spectrum, 12)
    assert auto.peak == approx(0)

    # Ensure cross-correlation of time-shifted impulses works.
    # Assume wave where y=[i==99].
    # Taking a slice beginning at index i will produce an impulse at 99-i.
    wave = np.eye(N)[::-1]
    left, right = wave[30], wave[40]

    # We need to slide `left` to the right by 10 samples, and vice versa.
    for radius in (None, 12):
        forward = correlate(data=left, prev_buffer=right, radius=radius)
        assert forward.peak == approx(10)

        backward = correlate(data=right, prev_buffer=left, radius=radius)
        assert backward.peak == approx(-10)
|
|
|
|
|
|
|
|
|
2018-08-26 05:44:42 +00:00
|
|
|
# Test the ability to load legacy TriggerConfig
|
|
|
|
|
2019-01-03 08:57:30 +00:00
|
|
|
|
2018-08-26 05:44:42 +00:00
|
|
|
def test_load_trigger_config():
    """Loading a config that uses legacy field names must not raise."""
    from corrscope.config import yaml

    legacy_document = """\
!CorrelationTriggerConfig
trigger_strength: 3
responsiveness: 0.2
falloff_width: 2
"""

    # Ensure no exceptions
    yaml.load(legacy_document)
|
|
|
|
|
2018-08-26 05:44:42 +00:00
|
|
|
|
2018-07-29 09:07:00 +00:00
|
|
|
# TODO test_period get_period()
|