diff --git a/tests/test_wave.py b/tests/test_wave.py index aa08b4a..78f5bf3 100644 --- a/tests/test_wave.py +++ b/tests/test_wave.py @@ -1,12 +1,15 @@ import warnings +from hypothesis import given +import hypothesis.strategies as hs import numpy as np from numpy.testing import assert_allclose import pytest from delayed_assert import expect, assert_expectations +from corrscope.config import CorrError from corrscope.utils.scipy_wavfile import WavFileWarning -from corrscope.wave import Wave +from corrscope.wave import Wave, Flatten prefix = "tests/wav-formats/" wave_paths = [ @@ -38,9 +41,11 @@ def test_wave(wave_path): assert not [str(w) for w in warns] +# Stereo tests + + def test_stereo_merge(): - """ Ensure stereo channels are combined properly, when indexing by slices - *or* ints. """ + """Test indexing Wave by slices *or* ints. Flatten using default SumAvg mode.""" # Contains a full-scale sine wave in left channel, and silence in right. # λ=100, nsamp=2000 @@ -71,11 +76,47 @@ def test_stereo_merge(): check_bound(wave[:]) +AllFlattens = hs.sampled_from(list(Flatten.__members__.values())) +ValidFlattens = hs.sampled_from(Flatten.modes) + + +@given(AllFlattens) +def test_stereo_flatten_modes(flatten: Flatten): + """Ensures all Flatten modes are handled properly + for stereo and mono signals.""" + wave = Wave(None, "tests/stereo in-phase.wav") + + if flatten not in Flatten.modes: + with pytest.raises(CorrError): + wave.with_flatten(flatten) + return + else: + wave = wave.with_flatten(flatten) + + nsamp = wave.nsamp + data = wave[:] + + # wave.data == 2-D array of shape (nsamp, nchan) + if flatten == Flatten.Stereo: + assert data.shape == (nsamp, 2) + else: + assert data.shape == (nsamp,) + if flatten & Flatten.Diff: + np.testing.assert_equal(data, 0) + else: + assert flatten & Flatten.Sum + if flatten & Flatten.IsAvg: + pass # FIXME + + def test_stereo_mmap(): wave = Wave(None, prefix + "stereo-sine-left-2000.wav") assert isinstance(wave.data, np.memmap) +# Miscellaneous tests + + def test_wave_subsampling(): wave = Wave(None, "tests/sine440.wav") # period = 48000 / 440 = 109.(09)*