From de285317f840c9a6544a17677d1967898e958a99 Mon Sep 17 00:00:00 2001 From: nyanpasu64 Date: Thu, 24 Jan 2019 22:27:36 -0800 Subject: [PATCH] Remove unneeded Wave validation, switch from Flag to Enum --- corrscope/wave.py | 19 ++++++++----------- tests/test_wave.py | 4 ++-- 2 files changed, 10 insertions(+), 13 deletions(-) diff --git a/corrscope/wave.py b/corrscope/wave.py index 1d07adb..1a51907 100644 --- a/corrscope/wave.py +++ b/corrscope/wave.py @@ -13,7 +13,7 @@ FLOAT = np.single @enum.unique -class Flatten(TypedEnumDump, enum.Flag): +class Flatten(TypedEnumDump, enum.Enum): """ How to flatten a stereo signal. (Channels beyond first 2 are ignored.) Flatten(0) == Flatten.Stereo == Flatten['Stereo'] @@ -63,21 +63,18 @@ class Wave: @flatten.setter def flatten(self, flatten: Flatten) -> None: - """ If self.is_mono, converts all non-Stereo modes to Mono. """ + # Reject invalid modes (including Mono). if flatten not in Flatten.modes: + # Flatten.Mono not in Flatten.modes. raise CorrError( f"Wave {self.wave_path} has invalid flatten mode {flatten} " f"not in {Flatten.modes}" ) + + # If self.is_mono, converts all non-Stereo modes to Mono. self._flatten = flatten - if self.is_mono: - if flatten != Flatten.Stereo: - self._flatten = Flatten.Mono - else: - if self.flatten == Flatten.Mono: - raise CorrError( - f"Cannot initialize stereo file {self.wave_path} with flatten=Mono" - ) + if self.is_mono and flatten != Flatten.Stereo: + self._flatten = Flatten.Mono def __init__( self, @@ -147,7 +144,7 @@ class Wave: data = data.reshape(-1) # ndarray.flatten() creates copy, is slow. elif flatten != Flatten.Stereo: # data.strides = (4,), so data == contiguous float32 - if flatten & Flatten.SumAvg: + if flatten == Flatten.SumAvg: data = data[..., 0] + data[..., 1] else: data = data[..., 0] - data[..., 1] diff --git a/tests/test_wave.py b/tests/test_wave.py index 76f98db..a3d69a8 100644 --- a/tests/test_wave.py +++ b/tests/test_wave.py @@ -116,14 +116,14 @@ def test_stereo_flatten_modes( assert data.shape == (nsamp,) # If DiffAvg and in-phase, L-R=0. - if flatten & Flatten.DiffAvg: + if flatten == Flatten.DiffAvg: if len(peaks) >= 2 and peaks[0] == peaks[1]: np.testing.assert_equal(data, 0) else: pass # If SumAvg, check average. else: - assert flatten & Flatten.SumAvg + assert flatten == Flatten.SumAvg assert_full_scale(data, np.mean(peaks))