corrscope/tests/test_channel.py

195 lines
6.2 KiB
Python

from contextlib import ExitStack
from typing import Optional
import hypothesis.strategies as hs
import numpy as np
import pytest
from hypothesis import given
from unittest.mock import patch
import corrscope.channel
import corrscope.corrscope
from corrscope.channel import ChannelConfig, Channel, DefaultLabel
from corrscope.corrscope import template_config, CorrScope, BenchmarkMode, Arguments
from corrscope.triggers import NullTriggerConfig
from corrscope.util import coalesce
from corrscope.wave import Flatten
# Hypothesis strategies shared by the property-based tests in this module.
positive = hs.integers(1, 100)  # integers in [1, 100]
real = hs.floats(0, 100)  # floats in [0, 100]
maybe_real = hs.none() | real  # None, or a float drawn from `real`
bools = hs.booleans()
default_labels = hs.sampled_from(DefaultLabel)  # any DefaultLabel member
@given(
    # Channel
    c_amplification=maybe_real,
    c_trigger_width=positive,
    c_render_width=positive,
    # Global
    amplification=real,
    trigger_ms=positive,
    render_ms=positive,
    tsub=positive,
    rsub=positive,
    default_label=default_labels,
    override_label=bools,
)
def test_config_channel_integration(
    # Channel
    c_amplification: Optional[float],
    c_trigger_width: int,
    c_render_width: int,
    # Global
    amplification: float,
    trigger_ms: int,
    render_ms: int,
    tsub: int,
    rsub: int,
    default_label: DefaultLabel,
    override_label: bool,
):
    """(Tautologically) verify that config values propagate correctly:

    - channel._render_samp (given cfg.render_ms and render_subsampling)
    - channel._trigger_stride / render_stride (given cfg.*_subsampling and
      ChannelConfig.*_width)
    - the amplification passed to Wave (channel override, else global)
    - trigger._tsamp, trigger._stride
    - renderer's method calls (samp, stride)
    - rendered label (channel.label, given cfg, corr_cfg.default_label)
    """
    with ExitStack() as stack:
        # region setup test variables

        # Silence timestamp printing to clean up Hypothesis test logs. Patched
        # rather than assigned, so the module flag is restored after every
        # example instead of leaking into other tests.
        stack.enter_context(
            patch.object(corrscope.corrscope, "PRINT_TIMESTAMP", False, create=True)
        )

        # Replace the Wave class used by corrscope.channel with a mock.
        Wave = stack.enter_context(patch.object(corrscope.channel, "Wave"))
        wave = Wave.return_value

        def get_around(sample: int, return_nsamp: int, stride: int):
            return np.zeros(return_nsamp)

        wave.get_around.side_effect = get_around
        wave.with_flatten.return_value = wave
        wave.nsamp = 10000
        wave.smp_s = 48000

        ccfg = ChannelConfig(
            "tests/sine440.wav",
            trigger_width=c_trigger_width,
            render_width=c_render_width,
            amplification=c_amplification,
            label="label" if override_label else "",
        )

        def get_cfg():
            return template_config(
                trigger_ms=trigger_ms,
                render_ms=render_ms,
                trigger_subsampling=tsub,
                render_subsampling=rsub,
                amplification=amplification,
                channels=[ccfg],
                default_label=default_label,
                trigger=NullTriggerConfig(),
                benchmark_mode=BenchmarkMode.OUTPUT,
            )

        # endregion

        cfg = get_cfg()
        channel = Channel(ccfg, cfg)

        # Ensure cfg.trigger_ms and cfg.render_ms are correct.
        assert cfg.trigger_ms == trigger_ms
        assert cfg.render_ms == render_ms

        # Ensure channel._render_samp and the trigger/render strides are correct.
        def ideal_samp(width_ms, sub):
            width_s = width_ms / 1000
            return pytest.approx(
                round(width_s * channel.trigger_wave.smp_s / sub), rel=1e-6
            )

        ideal_tsamp = ideal_samp(cfg.trigger_ms, tsub)
        ideal_rsamp = ideal_samp(cfg.render_ms, rsub)
        assert channel._render_samp == ideal_rsamp

        assert channel._trigger_stride == tsub * c_trigger_width
        assert channel.render_stride == rsub * c_render_width

        # Ensure the channel amplification overrides the global one when set.
        _args, kwargs = Wave.call_args
        assert kwargs["amplification"] == coalesce(c_amplification, amplification)

        ## Ensure trigger uses the ideal trigger sample count and _trigger_stride.
        trigger = channel.trigger
        assert trigger._tsamp == ideal_tsamp
        assert trigger._stride == channel._trigger_stride

        ## Ensure corrscope calls render using channel._render_samp and render_stride.
        corr = CorrScope(cfg, Arguments(cfg_dir=".", outputs=[]))
        renderer = stack.enter_context(
            patch.object(CorrScope, "_load_renderer")
        ).return_value
        corr.play()

        # Only Channel.get_render_around() (not NullTrigger) calls wave.get_around().
        (_sample, _return_nsamp, _subsampling), kwargs = wave.get_around.call_args
        assert _return_nsamp == channel._render_samp
        assert _subsampling == channel.render_stride

        # Inspect arguments to renderer.update_main_lines()
        # datas: List[np.ndarray]
        (datas,), kwargs = renderer.update_main_lines.call_args
        render_data = datas[0]
        assert len(render_data) == channel._render_samp

        # Inspect arguments to renderer.add_labels().
        (labels,), kwargs = renderer.add_labels.call_args
        label = labels[0]
        if override_label:
            assert label == "label"
        elif default_label is DefaultLabel.FileName:
            assert label == "sine440"
        elif default_label is DefaultLabel.Number:
            assert label == "1"
        else:
            assert label == ""

        # line_color is tested in test_renderer.py
@pytest.mark.parametrize("filename", ["tests/sine440.wav", "tests/stereo in-phase.wav"])
@pytest.mark.parametrize(
    ("global_stereo", "chan_stereo"),
    [
        [Flatten.SumAvg, None],
        [Flatten.Stereo, None],
        [Flatten.SumAvg, Flatten.Stereo],
        [Flatten.Stereo, Flatten.SumAvg],
        # NOTE(review): a raw string rather than a Flatten member; the
        # Optional[Flatten] hint below does not describe this case.
        [Flatten.Stereo, "1 0"],
    ],
)
def test_per_channel_stereo(
    filename: str, global_stereo: Flatten, chan_stereo: Optional[Flatten]
):
    """Ensure you can enable/disable stereo on a per-channel basis.

    Runs against both a mono and a stereo input file, combining a global
    render_stereo setting with an optional per-channel override.
    """
    # The per-channel setting wins when it is not None.
    stereo = coalesce(chan_stereo, global_stereo)

    # Test render wave, loading the parametrized file so both inputs are exercised.
    cfg = template_config(render_stereo=global_stereo)
    ccfg = ChannelConfig(filename, render_stereo=chan_stereo)
    channel = Channel(ccfg, cfg)

    # Render wave *must* return stereo (2-D data), even for mono input.
    assert channel.render_wave[0:1].ndim == 2
    data = channel.render_wave.get_around(0, return_nsamp=4, stride=1)
    assert data.ndim == 2

    # Only the stereo file is checked for the requested flatten mode and
    # channel count.
    if "stereo" in filename:
        assert channel.render_wave._flatten == stereo
        assert data.shape[1] == (2 if stereo is Flatten.Stereo else 1)