diff --git a/corrscope/gui/__init__.py b/corrscope/gui/__init__.py index bc64689..447b038 100644 --- a/corrscope/gui/__init__.py +++ b/corrscope/gui/__init__.py @@ -32,6 +32,8 @@ from corrscope.gui.model_bind import ( rsetattr, Symbol, SymbolText, + BoundComboBox, + _call_all, ) from corrscope.gui.util import color2hex, Locked, find_ranges, TracebackDialog from corrscope.gui.view_mainwindow import MainWindow as Ui_MainWindow @@ -289,7 +291,7 @@ class MainWindow(qw.QMainWindow, Ui_MainWindow): if name: master_audio = "master_audio" self.model[master_audio] = name - self.model.update_widget[master_audio]() + self.model.update_all_bound(master_audio) def on_separate_render_dir_toggled(self, checked: bool): self.pref.separate_render_dir = checked @@ -580,7 +582,7 @@ def nrow_ncol_property(altered: str, unaltered: str) -> property: if val > 0: setattr(self.cfg.layout, altered, val) setattr(self.cfg.layout, unaltered, None) - self.update_widget["layout__" + unaltered]() + self.update_all_bound("layout__" + unaltered) elif val == 0: setattr(self.cfg.layout, altered, None) else: @@ -637,8 +639,12 @@ class ConfigModel(PresentationModel): master_audio = path_fix_property("master_audio") # Stereo flattening - combo_symbol_text["trigger_stereo"] = list(flatten_no_stereo.items()) - combo_symbol_text["render_stereo"] = list(flatten_modes.items()) + combo_symbol_text["trigger_stereo"] = list(flatten_no_stereo.items()) + [ + (BoundComboBox.Custom, "Custom") + ] + combo_symbol_text["render_stereo"] = list(flatten_modes.items()) + [ + (BoundComboBox.Custom, "Custom") + ] # Trigger @property @@ -826,6 +832,7 @@ class ChannelModel(qc.QAbstractTableModel): Column("wav_path", path_strip_quotes, "", "WAV Path"), Column("amplification", float, None, "Amplification\n(override)"), Column("line_color", str, None, "Line Color"), + Column("render_stereo", str, None, "Render Stereo\nDownmix"), Column("trigger_width", int, None, "Trigger Width ×"), Column("render_width", int, None, "Render Width ×"), Column("trigger__buffer_strength", float, None), @@ -865,7 +872,7 @@ class ChannelModel(qc.QAbstractTableModel): # data TRIGGER = "trigger__" - def data(self, index: QModelIndex, role=Qt.DisplayRole) -> qc.QVariant: + def data(self, index: QModelIndex, role=Qt.DisplayRole) -> Any: col = index.column() row = index.row() diff --git a/corrscope/gui/model_bind.py b/corrscope/gui/model_bind.py index 1c926fa..71769d3 100644 --- a/corrscope/gui/model_bind.py +++ b/corrscope/gui/model_bind.py @@ -1,5 +1,6 @@ import functools import operator +from collections import defaultdict from typing import * import attr @@ -10,8 +11,8 @@ from PyQt5.QtWidgets import QWidget from corrscope.config import CorrError, DumpableAttrs, get_units from corrscope.gui.util import color2hex -from corrscope.utils.trigger_util import lerp from corrscope.util import obj_name, perr +from corrscope.utils.trigger_util import lerp if TYPE_CHECKING: from corrscope.gui import MainWindow @@ -37,6 +38,12 @@ Symbol = Hashable SymbolText = Tuple[Symbol, str] + +def _call_all(updaters: List[WidgetUpdater]): + for updater in updaters: + updater() + + # Data binding presentation-model class PresentationModel(qc.QObject): """ Key-value MVP presentation-model. @@ -54,7 +61,7 @@ class PresentationModel(qc.QObject): def __init__(self, cfg: DumpableAttrs): super().__init__() self.cfg = cfg - self.update_widget: Dict[str, WidgetUpdater] = {} + self.update_widget: Dict[str, List[WidgetUpdater]] = defaultdict(list) def __getitem__(self, item: str) -> Any: try: @@ -76,8 +83,11 @@ class PresentationModel(qc.QObject): def set_cfg(self, cfg: DumpableAttrs): self.cfg = cfg - for updater in self.update_widget.values(): - updater() + for updater_list in self.update_widget.values(): + _call_all(updater_list) + + def update_all_bound(self, key: str): + _call_all(self.update_widget[key]) SKIP_BINDING = "skip" @@ -137,7 +147,7 @@ class BoundWidget(QWidget): if connect_to_model: # Allow widget to be updated by other events. - model.update_widget[path] = self.cfg2gui + model.update_widget[path].append(self.cfg2gui) # Allow pmodel to be changed by widget. self.gui_changed.connect(self.set_model) @@ -147,6 +157,8 @@ class BoundWidget(QWidget): perr(path) raise + # TODO unbind_widget(), model.update_widget[path].remove(self.cfg2gui)? + def calc_error_palette(self) -> QPalette: """ Palette with red background, used for widgets with invalid input. """ error_palette = QPalette(self.palette()) @@ -179,7 +191,9 @@ class BoundWidget(QWidget): gui_changed: ClassVar[Signal] def set_model(self, value): - pass + for updater in self.pmodel.update_widget[self.path]: + if updater != self.cfg2gui: + updater() def blend_colors( @@ -212,6 +226,7 @@ def model_setter(value_type: type) -> Callable[..., None]: except CorrError: self.setPalette(self.error_palette) else: + BoundWidget.set_model(self, value) self.setPalette(self.default_palette) return set_model @@ -285,6 +300,9 @@ class BoundComboBox(qw.QComboBox, BoundWidget): combo_symbol_text: Sequence[SymbolText] symbol2idx: Dict[Symbol, int] + Custom = object() + custom_if_unmatched: bool + # noinspection PyAttributeOutsideInit def bind_widget(self, model: PresentationModel, path: str, *args, **kwargs) -> None: # Effectively enum values. @@ -292,9 +310,13 @@ class BoundComboBox(qw.QComboBox, BoundWidget): # symbol2idx[enum] = combo-box index self.symbol2idx = {} + self.custom_if_unmatched = False for i, (symbol, text) in enumerate(self.combo_symbol_text): self.symbol2idx[symbol] = i + if symbol is self.Custom: + self.custom_if_unmatched = True + # Pretty-printed text self.addItem(text) @@ -302,7 +324,12 @@ class BoundComboBox(qw.QComboBox, BoundWidget): # combobox.index = pmodel.attr def set_gui(self, model_value: Any) -> None: - combo_index = self.symbol2idx[self._symbol_from_value(model_value)] + symbol = self._symbol_from_value(model_value) + if self.custom_if_unmatched and symbol not in self.symbol2idx: + combo_index = self.symbol2idx[self.Custom] + else: + combo_index = self.symbol2idx[self._symbol_from_value(model_value)] + self.setCurrentIndex(combo_index) @staticmethod @@ -316,7 +343,9 @@ class BoundComboBox(qw.QComboBox, BoundWidget): def set_model(self, combo_index: int): assert isinstance(combo_index, int) combo_symbol, _ = self.combo_symbol_text[combo_index] - self.pmodel[self.path] = self._value_from_symbol(combo_symbol) + if combo_symbol is not self.Custom: + self.pmodel[self.path] = self._value_from_symbol(combo_symbol) + BoundWidget.set_model(self, None) @staticmethod def _value_from_symbol(symbol: Symbol): @@ -445,6 +474,7 @@ class _ColorText(BoundLineEdit): self.setPalette(self.default_palette) self.hex_color.emit(value or "") # calls button.set_color() self.pmodel[self.path] = value + BoundWidget.set_model(self, value) def sizeHint(self) -> qc.QSize: """Reduce the width taken up by #rrggbb color text boxes.""" diff --git a/corrscope/gui/view_mainwindow.py b/corrscope/gui/view_mainwindow.py index 9372a69..c9dfe58 100644 --- a/corrscope/gui/view_mainwindow.py +++ b/corrscope/gui/view_mainwindow.py @@ -174,9 +174,15 @@ class MainWindow(QWidget): with add_row(s, "", BoundComboBox) as self.trigger_stereo: pass + with add_row(s, tr("Downmix"), BoundLineEdit, name="trigger_stereo"): + pass + with add_row(s, "", BoundComboBox) as self.render_stereo: pass + with add_row(s, tr("Downmix"), BoundLineEdit, name="render_stereo"): + pass + with append_widget(s, QGroupBox) as self.dockStereo_2: set_layout(s, QFormLayout) diff --git a/corrscope/wave.py b/corrscope/wave.py index c891746..b33d096 100644 --- a/corrscope/wave.py +++ b/corrscope/wave.py @@ -1,55 +1,114 @@ import copy import enum -import warnings -from enum import auto from typing import Union, List import numpy as np import corrscope.utils.scipy.wavfile as wavfile -from corrscope.config import CorrError, CorrWarning, TypedEnumDump +from corrscope.config import CorrError, TypedEnumDump FLOAT = np.single +# Depends on FLOAT +from corrscope.utils.windows import rightpad + @enum.unique -class Flatten(TypedEnumDump): +class Flatten(str, TypedEnumDump): """ How to flatten a stereo signal. (Channels beyond first 2 are ignored.) Flatten(0) == Flatten.Stereo == Flatten['Stereo'] """ # Keep both channels. - Stereo = 0 + Stereo = "stereo" # Mono - Mono = auto() # NOT publicly exposed + Mono = "1" # NOT publicly exposed # Take sum or difference. - SumAvg = auto() - DiffAvg = auto() + SumAvg = "1 1" + DiffAvg = "1, -1" + + def __str__(self): + return self.value + + # Both our app and GUI treat: + # - Flatten.SumAvg -> "sum of all channels" + # - "1 1" -> "assert nchan == 2, left + right". + # - "1 0" -> "assert nchan == 2, left". + def __eq__(self, other): + return self is other + + def __hash__(self): + return hash(self.value) modes: List["Flatten"] +assert "1" == str(Flatten.Mono) +assert not "1" == Flatten.Mono +assert not Flatten.Mono == "1" + +FlattenOrStr = Union[Flatten, str] + + +def calc_flatten_matrix(flatten: FlattenOrStr, stereo_nchan: int) -> np.ndarray: + """ Raises CorrError on invalid input. + + If flatten is Flatten.Stereo, returns shape=(nchan,nchan) identity matrix. + - (N,nchan) @ (nchan,nchan) = (N,nchan). + + Otherwise, returns shape=(nchan) flattening matrix. + - (N,nchan) @ (nchan) = (N) + + https://docs.scipy.org/doc/numpy/reference/generated/numpy.matmul.html#numpy.matmul + ''' + If the second argument is 1-D, + it is promoted to a matrix by appending a 1 to its dimensions. + After matrix multiplication the appended 1 is removed." + ''' + """ + + if flatten is Flatten.Stereo: + # 2D identity (results in 2-dim data) + flatten_matrix = np.eye(stereo_nchan, dtype=FLOAT) + + # 1D (results in 1-dim data) + elif flatten is Flatten.SumAvg: + flatten_matrix = np.ones(stereo_nchan, dtype=FLOAT) / stereo_nchan + + elif flatten is Flatten.DiffAvg: + flatten_matrix = calc_flatten_matrix(str(flatten), stereo_nchan) + flatten_matrix = rightpad(flatten_matrix, stereo_nchan, 0) + + else: + words = flatten.replace(",", " ").split() + try: + flatten_matrix = np.array([FLOAT(word) for word in words]) + except ValueError as e: + raise CorrError("Invalid stereo flattening matrix") from e + + flatten_abs_sum = np.sum(np.abs(flatten_matrix)) + if flatten_abs_sum == 0: + raise CorrError("Stereo flattening matrix must have nonzero elements") + + flatten_matrix /= flatten_abs_sum + + assert flatten_matrix.dtype == FLOAT, flatten_matrix.dtype + return flatten_matrix + + _rejected_modes = {Flatten.Mono} Flatten.modes = [f for f in Flatten.__members__.values() if f not in _rejected_modes] class Wave: - __slots__ = """ - wave_path - amplification offset - smp_s data return_channels _flatten is_mono - nsamp dtype - center max_val - """.split() - smp_s: int - data: "np.ndarray" - """2-D array of shape (nsamp, nchan)""" + data: np.ndarray - _flatten: Flatten + _flatten: FlattenOrStr + flatten_matrix: np.ndarray @property def flatten(self) -> Flatten: @@ -64,13 +123,13 @@ class Wave: return self._flatten @flatten.setter - def flatten(self, flatten: Flatten) -> None: + def flatten(self, flatten: FlattenOrStr) -> None: # Reject invalid modes (including Mono). - if flatten not in Flatten.modes: # type: ignore + if flatten in _rejected_modes: # Flatten.Mono not in Flatten.modes. raise CorrError( f"Wave {self.wave_path} has invalid flatten mode {flatten} " - f"not in {Flatten.modes}" + f"not a numeric string, nor in {Flatten.modes}" ) # If self.is_mono, converts all non-Stereo modes to Mono. @@ -78,6 +137,8 @@ class Wave: if self.is_mono and flatten != Flatten.Stereo: self._flatten = Flatten.Mono + self.flatten_matrix = calc_flatten_matrix(self._flatten, self.stereo_nchan) + def __init__( self, wave_path: str, @@ -87,28 +148,26 @@ class Wave: self.wave_path = wave_path self.amplification = amplification self.offset = 0 + + # self.data: 2-D array of shape (nsamp, nchan) self.smp_s, self.data = wavfile.read(wave_path, mmap=True) assert self.data.ndim in [1, 2] self.is_mono = self.data.ndim == 1 - self.flatten = flatten self.return_channels = False # Cast self.data to stereo (nsamp, nchan) if self.is_mono: self.data.shape = (-1, 1) - self.nsamp, stereo_nchan = self.data.shape - if stereo_nchan > 2: - warnings.warn( - f"File {wave_path} has {stereo_nchan} channels, " - f"only first 2 will be used", - CorrWarning, - ) + self.nsamp, self.stereo_nchan = self.data.shape - dtype = self.data.dtype + # Depends on self.stereo_nchan + self.flatten = flatten # Calculate scaling factor. + dtype = self.data.dtype + def is_type(parent: type) -> bool: return np.issubdtype(dtype, parent) @@ -152,16 +211,7 @@ class Wave: data: np.ndarray = self.data[index].astype(FLOAT, subok=False, copy=True) # Flatten stereo to mono. - flatten = self._flatten # Potentially faster than property getter. - if flatten == Flatten.Mono: - data = data.reshape(-1) # ndarray.flatten() creates copy, is slow. - elif flatten != Flatten.Stereo: - # data.strides = (4,), so data == contiguous float32 - if flatten == Flatten.SumAvg: - data = data[..., 0] + data[..., 1] - else: - data = data[..., 0] - data[..., 1] - data /= 2 + data = data @ self.flatten_matrix data -= self.center data *= self.amplification / self.max_val diff --git a/tests/test_channel.py b/tests/test_channel.py index 4fa8cc8..314d758 100644 --- a/tests/test_channel.py +++ b/tests/test_channel.py @@ -145,6 +145,7 @@ def test_config_channel_width_stride( [Flatten.Stereo, None], [Flatten.SumAvg, Flatten.Stereo], [Flatten.Stereo, Flatten.SumAvg], + [Flatten.Stereo, "1 0"], ], ) def test_per_channel_stereo( diff --git a/tests/test_wave.py b/tests/test_wave.py index 2680c66..f2a9e80 100644 --- a/tests/test_wave.py +++ b/tests/test_wave.py @@ -8,7 +8,7 @@ from delayed_assert import expect, assert_expectations from corrscope.config import CorrError from corrscope.utils.scipy.wavfile import WavFileWarning -from corrscope.wave import Wave, Flatten +from corrscope.wave import Wave, Flatten, calc_flatten_matrix prefix = "tests/wav-formats/" wave_paths = [ @@ -43,6 +43,55 @@ def test_wave(wave_path): # Stereo tests +def arr(*args): + return np.array(args) + + +def test_calc_flatten_matrix(): + nchan = 3 + + # Test Stereo + np.testing.assert_equal(calc_flatten_matrix(Flatten.Stereo, nchan), np.eye(nchan)) + + # Test SumAvg on various channel counts + np.testing.assert_equal(calc_flatten_matrix(Flatten.SumAvg, 1), [1]) + np.testing.assert_equal(calc_flatten_matrix(Flatten.SumAvg, 2), [0.5, 0.5]) + np.testing.assert_equal(calc_flatten_matrix(Flatten.SumAvg, 4), [0.25] * 4) + + # Test DiffAvg on various channel counts + # (Wave will use Mono instead of DiffAvg, on mono audio signals. + # But ensure it doesn't crash anyway.) + np.testing.assert_equal(calc_flatten_matrix(Flatten.DiffAvg, 1), [0.5]) + np.testing.assert_equal(calc_flatten_matrix(Flatten.DiffAvg, 2), [0.5, -0.5]) + np.testing.assert_equal(calc_flatten_matrix(Flatten.DiffAvg, 4), [0.5, -0.5, 0, 0]) + + # Test Mono + np.testing.assert_equal(calc_flatten_matrix(Flatten.Mono, 1), [1]) + + # Test custom strings and delimiters + out = arr(1, 2, 1) + nchan = 3 + np.testing.assert_equal(calc_flatten_matrix(",1,2,1,", nchan), out / sum(out)) + np.testing.assert_equal(calc_flatten_matrix(" 1, 2, 1 ", nchan), out / sum(out)) + np.testing.assert_equal(calc_flatten_matrix("1 2 1", nchan), out / sum(out)) + + # Test negative values + nchan = 2 + np.testing.assert_equal(calc_flatten_matrix("1, -1", nchan), arr(1, -1) / 2) + np.testing.assert_equal(calc_flatten_matrix("-1, 1", nchan), arr(-1, 1) / 2) + np.testing.assert_equal(calc_flatten_matrix("-1, -1", nchan), arr(-1, -1) / 2) + + # Test invalid inputs + with pytest.raises(CorrError): + calc_flatten_matrix("", 0) + + with pytest.raises(CorrError): + calc_flatten_matrix("1 -1 uwu", 3) + + with pytest.raises(CorrError): + calc_flatten_matrix("0 0", 2) + + def test_stereo_merge(): """Test indexing Wave by slices *or* ints. Flatten using default SumAvg mode.""" @@ -75,7 +124,7 @@ def test_stereo_merge(): check_bound(wave[:]) -AllFlattens = Flatten.__members__.values() +AllFlattens = [*Flatten.__members__.values(), "1 1", "1 0", "1 -1"] @pytest.mark.parametrize("flatten", AllFlattens) @@ -107,7 +156,7 @@ def test_stereo_flatten_modes( assert nchan == len(peaks) wave = Wave(path) - if flatten not in Flatten.modes: + if flatten is Flatten.Mono: with pytest.raises(CorrError): wave.with_flatten(flatten, return_channels) return @@ -135,9 +184,9 @@ def test_stereo_flatten_modes( else: pass # If SumAvg, check average. - else: - assert flatten == Flatten.SumAvg + elif flatten == Flatten.SumAvg: assert_full_scale(data, np.mean(peaks)) + # Don't test custom string modes for now. def assert_full_scale(data, peak):