diff --git a/corrscope/channel.py b/corrscope/channel.py index cdcb0b2..fe3ffa6 100644 --- a/corrscope/channel.py +++ b/corrscope/channel.py @@ -62,8 +62,8 @@ class Channel: tflat = coalesce(cfg.trigger_stereo, corr_cfg.trigger_stereo) rflat = coalesce(cfg.render_stereo, corr_cfg.render_stereo) - self.trigger_wave = wave.with_flatten(tflat) - self.render_wave = wave.with_flatten(rflat) + self.trigger_wave = wave.with_flatten(tflat, return_channels=False) + self.render_wave = wave.with_flatten(rflat, return_channels=True) # `subsampling` increases `stride` and decreases `nsamp`. # `width` increases `stride` without changing `nsamp`. diff --git a/corrscope/layout.py b/corrscope/layout.py index d0ffa7b..6ee3d8e 100644 --- a/corrscope/layout.py +++ b/corrscope/layout.py @@ -1,16 +1,41 @@ -from typing import Optional, TypeVar, Callable, List, Generic, Tuple +import collections +import enum +from enum import auto +from typing import Optional, TypeVar, Callable, List, Iterable +import attr import numpy as np -from corrscope.config import DumpableAttrs, CorrError +from corrscope.config import DumpableAttrs, CorrError, DumpEnumAsStr from corrscope.util import ceildiv -class LayoutConfig(DumpableAttrs, always_dump="orientation"): - orientation: str = "h" +class Orientation(str, DumpEnumAsStr): + h = "h" + v = "v" + + +class StereoOrientation(str, DumpEnumAsStr): + h = "h" + v = "v" + overlay = "overlay" + + +assert Orientation.h == StereoOrientation.h +H = Orientation.h +V = Orientation.v +OVERLAY = StereoOrientation.overlay + + +class LayoutConfig(DumpableAttrs, always_dump="orientation stereo_orientation"): + orientation: Orientation = attr.ib(default="h", converter=Orientation) nrows: Optional[int] = None ncols: Optional[int] = None + stereo_orientation: StereoOrientation = attr.ib( + default="h", converter=StereoOrientation + ) + def __attrs_post_init__(self) -> None: if not self.nrows: self.nrows = None @@ -24,31 +49,90 @@ class LayoutConfig(DumpableAttrs, always_dump="orientation"): self.ncols = 1 +class Edges(enum.Flag): + NONE = 0 + Top = auto() + Left = auto() + Bottom = auto() + Right = auto() + + @staticmethod + def at(nrows: int, ncols: int, row: int, col: int): + if not nrows > 0: + raise ValueError(f"invalid nrows={nrows}, must be positive") + if not ncols > 0: + raise ValueError(f"invalid ncols={ncols}, must be positive") + if not 0 <= row < nrows: + raise ValueError(f"invalid row={row} not in [0 .. nrows={nrows})") + if not 0 <= col < ncols: + raise ValueError(f"invalid col={col} not in [0 .. ncols={ncols})") + + ret = Edges.NONE + if row == 0: + ret |= Edges.Top + if row + 1 == nrows: + ret |= Edges.Bottom + if col == 0: + ret |= Edges.Left + if col + 1 == ncols: + ret |= Edges.Right + return ret + + +def attr_idx_property(key: str, idx: int) -> property: + @property + def inner(self: "RegionSpec"): + return getattr(self, key)[idx] + + return inner + + +@attr.dataclass +class RegionSpec: + """ + - Origin is located at top-left. + - Row 0 = top. + - Row nrows-1 = bottom. + - Col 0 = left. + - Col ncols-1 = right. + """ + + size: np.ndarray + pos: np.ndarray + + nrow = attr_idx_property("size", 0) + ncol = attr_idx_property("size", 1) + row = attr_idx_property("pos", 0) + col = attr_idx_property("pos", 1) + + screen_edges: "Edges" + wave_edges: "Edges" + + Region = TypeVar("Region") -RegionFactory = Callable[[int, int], Region] # f(row, column) -> Region +RegionFactory = Callable[[RegionSpec], Region] # f(row, column) -> Region class RendererLayout: - VALID_ORIENTATIONS = ["h", "v"] - - def __init__(self, cfg: LayoutConfig, nplots: int): + def __init__(self, cfg: LayoutConfig, wave_nchans: List[int]): self.cfg = cfg - self.nplots = nplots - - # Setup layout - self.nrows, self.ncols = self._calc_layout() + self.nwaves = len(wave_nchans) + self.wave_nchans = wave_nchans self.orientation = cfg.orientation - if self.orientation not in self.VALID_ORIENTATIONS: - raise CorrError( - f"Invalid orientation {self.orientation} not in " - f"{self.VALID_ORIENTATIONS}" - ) + self.stereo_orientation = cfg.stereo_orientation - def _calc_layout(self) -> Tuple[int, int]: + # Setup layout + self._calc_layout() + + # Shape of wave slots + wave_nrow: int + wave_ncol: int + + def _calc_layout(self) -> None: """ - Inputs: self.cfg, self.waves - :return: (nrows, ncols) + Inputs: self.cfg, self.stereo_nchan + Outputs: self.wave_nrow, ncol """ cfg = self.cfg @@ -56,7 +140,7 @@ class RendererLayout: nrows = cfg.nrows if nrows is None: raise ValueError("impossible cfg: nrows is None and true") - ncols = ceildiv(self.nplots, nrows) + ncols = ceildiv(self.nwaves, nrows) else: if cfg.ncols is None: raise ValueError( @@ -64,38 +148,110 @@ class RendererLayout: "(__attrs_post_init__ not called?)" ) ncols = cfg.ncols - nrows = ceildiv(self.nplots, ncols) + nrows = ceildiv(self.nwaves, ncols) - return nrows, ncols + self.wave_nrow = nrows + self.wave_ncol = ncols - def arrange(self, region_factory: RegionFactory[Region]) -> List[Region]: - """ Generates an array of regions. - - index, row, column are fed into region_factory in a row-major order [row][col]. - The results are possibly reshaped into column-major order [col][row]. + def arrange(self, region_factory: RegionFactory[Region]) -> List[List[Region]]: """ - nspaces = self.nrows * self.ncols - inds = np.arange(nspaces) - rows, cols = np.unravel_index(inds, (self.nrows, self.ncols)) + (row, column) are fed into region_factory in a row-major order [row][col]. + Stereo channel pairs are extracted. + The results are possibly reshaped into column-major order [col][row]. - row_col = list(zip(rows, cols)) - regions = np.empty(len(row_col), dtype=object) # type: np.ndarray[Region] - regions[:] = [region_factory(*rc) for rc in row_col] + :return arr[wave][channel] = Region + """ - regions2d = regions.reshape( - (self.nrows, self.ncols) - ) # type: np.ndarray[Region] + wave_spaces = self.wave_nrow * self.wave_ncol + inds = np.arange(wave_spaces) - # if column major: - if self.orientation == "v": - regions2d = regions2d.T + # Compute location of each wave. + if self.orientation == V: + # column major + cols, rows = np.unravel_index(inds, (self.wave_ncol, self.wave_nrow)) + else: + # row major + rows, cols = np.unravel_index(inds, (self.wave_nrow, self.wave_ncol)) - return regions2d.flatten()[: self.nplots].tolist() + # Generate plot for each wave.chan. Leave unused slots empty. + region_wave_chan: List[List[Region]] = [] + + # The order of (rows, cols) has no effect. + for stereo_nchan, wave_row, wave_col in zip(self.wave_nchans, rows, cols): + # Wave = within Screen. + # Chan = within Wave, generate a plot. + # All arrays are [y, x] == [row, col]. + + # Wave dim/pos (within screen) + waves_per_screen = arr(self.wave_nrow, self.wave_ncol) + wave_screen_pos = arr(wave_row, wave_col) + del wave_row, wave_col + + # Channel dim/pos (within wave) + chans_per_wave = arr(1, 1) # Mutated based on orientation + chan_wave_pos = arr(0, 0) # Mutated in for-chan loop. + + # Distance between chans + dchan = arr(0, 0) # Mutated based on orientation + + if self.stereo_orientation == V: + chans_per_wave[0] = stereo_nchan + dchan[0] = 1 + elif self.stereo_orientation == H: + chans_per_wave[1] = stereo_nchan + dchan[1] = 1 + else: + assert self.stereo_orientation == OVERLAY + + # Channel dim/pos (within screen) + chans_per_screen = chans_per_wave * waves_per_screen + chan_screen_pos = chans_per_wave * wave_screen_pos + + # Generate plots for each channel + region_chan: List[Region] = [] + region_wave_chan.append(region_chan) + region = None + + for chan in range(stereo_nchan): + assert (chan_wave_pos < chans_per_wave).all() + assert (chan_screen_pos < chans_per_screen).all() + + # Generate plot (channel position in screen) + if region is None or dchan.any(): + screen_edges = Edges.at(*chans_per_screen, *chan_screen_pos) + wave_edges = Edges.at(*chans_per_wave, *chan_wave_pos) + + # Removing .copy() causes bugs if region_factory() holds + # mutable references. + chan_spec = RegionSpec( + chans_per_screen.copy(), + chan_screen_pos.copy(), + screen_edges, + wave_edges, + ) + region = region_factory(chan_spec) + region_chan.append(region) + + # Move to next channel position + chan_screen_pos += dchan + chan_wave_pos += dchan + + assert len(region_wave_chan) == self.nwaves + return region_wave_chan -class EdgeFinder(Generic[Region]): - def __init__(self, regions2d: np.ndarray): - self.tops: List[Region] = regions2d[0, :].tolist() - self.bottoms: List[Region] = regions2d[-1, :].tolist() - self.lefts: List[Region] = regions2d[:, 0].tolist() - self.rights: List[Region] = regions2d[:, -1].tolist() +def arr(*args): + return np.array(args) + + +T = TypeVar("T") + + +def unique_by_id(items: Iterable[T]) -> List[T]: + seen = collections.OrderedDict() + + for item in items: + if id(item) not in seen: + seen[id(item)] = item + + return list(seen.values()) diff --git a/corrscope/renderer.py b/corrscope/renderer.py index 08fe21e..54123c8 100644 --- a/corrscope/renderer.py +++ b/corrscope/renderer.py @@ -4,10 +4,17 @@ from typing import Optional, List, TYPE_CHECKING, Any import attr import matplotlib +import matplotlib.colors import numpy as np from corrscope.config import DumpableAttrs, with_units -from corrscope.layout import RendererLayout, LayoutConfig, EdgeFinder +from corrscope.layout import ( + RendererLayout, + LayoutConfig, + unique_by_id, + RegionSpec, + Edges, +) from corrscope.outputs import RGB_DEPTH, ByteBuffer from corrscope.util import coalesce @@ -58,7 +65,10 @@ class RendererConfig(DumpableAttrs, always_dump="*"): bg_color: str = "#000000" init_line_color: str = default_color() + grid_color: Optional[str] = None + stereo_grid_opacity: float = 0.5 + midline_color: Optional[str] = None v_midline: bool = False h_midline: bool = False @@ -99,8 +109,8 @@ class Renderer(ABC): channel_cfgs: Optional[List["ChannelConfig"]], ): self.cfg = cfg + self.lcfg = lcfg self.nplots = nplots - self.layout = RendererLayout(lcfg, nplots) # Load line colors. if channel_cfgs is not None: @@ -165,16 +175,20 @@ class MatplotlibRenderer(Renderer): matplotlib.rcParams, "lines.antialiased", self.cfg.antialiasing ) - # Flat array of nrows*ncols elements, ordered by cfg.rows_first. self._fig: "Figure" - self._axes: List["Axes"] # set by set_layout() - self._lines: Optional[List["Line2D"]] = None # set by render_frame() first call - self._set_layout() # mutates self + # _axes2d[wave][chan] = Axes + self._axes2d: List[List["Axes"]] # set by set_layout() + + # _lines2d[wave][chan] = Line2D + self._lines2d: List[List[Line2D]] = [] + self._lines_flat: List["Line2D"] = [] transparent = "#00000000" - def _set_layout(self) -> None: + layout: RendererLayout + + def _set_layout(self, wave_nchans: List[int]) -> None: """ Creates a flat array of Matplotlib Axes, with the new layout. Opens a window showing the Figure (and Axes). @@ -183,6 +197,8 @@ class MatplotlibRenderer(Renderer): Outputs: self.nrows, self.ncols, self.axes """ + self.layout = RendererLayout(self.lcfg, wave_nchans) + # Create Axes # https://matplotlib.org/api/_as_gen/matplotlib.pyplot.subplots.html if hasattr(self, "_fig"): @@ -190,24 +206,24 @@ class MatplotlibRenderer(Renderer): # plt.close(self.fig) grid_color = self.cfg.grid_color - axes2d: np.ndarray["Axes"] self._fig = Figure() FigureCanvasAgg(self._fig) - axes2d = self._fig.subplots( - self.layout.nrows, - self.layout.ncols, - squeeze=False, - # Remove axis ticks (which slow down rendering) - subplot_kw=dict(xticks=[], yticks=[]), - # Remove gaps between Axes TODO borders shouldn't be half-visible - gridspec_kw=dict(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0), - ) + # RegionFactory + def axes_factory(r: RegionSpec) -> "Axes": + width = 1 / r.ncol + left = r.col / r.ncol + assert 0 <= left < 1 - ax: "Axes" - if grid_color: - # Initialize borders - for ax in axes2d.flatten(): + height = 1 / r.nrow + bottom = (r.nrow - r.row - 1) / r.nrow + assert 0 <= bottom < 1 + + # Disabling xticks/yticks is unnecessary, since we hide Axises. + ax = self._fig.add_axes([left, bottom, width, height], xticks=[], yticks=[]) + + if grid_color: + # Initialize borders # Hide Axises # (drawing them is very slow, and we disable ticks+labels anyway) ax.get_xaxis().set_visible(False) @@ -223,25 +239,44 @@ class MatplotlibRenderer(Renderer): for spine in ax.spines.values(): spine.set_color(grid_color) - # gridspec_kw indexes from bottom-left corner. - # Only show bottom-left borders (x=0, y=0) - ax.spines["top"].set_visible(False) - ax.spines["right"].set_visible(False) + def hide(key: str): + ax.spines[key].set_visible(False) - # Hide bottom-left edges for speed. - edge_axes: EdgeFinder["Axes"] = EdgeFinder(axes2d) - for ax in edge_axes.bottoms: - ax.spines["bottom"].set_visible(False) - for ax in edge_axes.lefts: - ax.spines["left"].set_visible(False) + # Hide all axes except bottom-right. + hide("top") + hide("left") - else: - # Remove Axis from Axes - for ax in axes2d.flatten(): + # If bottom of screen, hide bottom. If right of screen, hide right. + if r.screen_edges & Edges.Bottom: + hide("bottom") + if r.screen_edges & Edges.Right: + hide("right") + + # Dim stereo gridlines + if self.cfg.stereo_grid_opacity > 0: + dim_color = matplotlib.colors.to_rgba_array(grid_color)[0] + dim_color[-1] = self.cfg.stereo_grid_opacity + + def dim(key: str): + ax.spines[key].set_color(dim_color) + + else: + dim = hide + + # If not bottom of wave, dim bottom. If not right of wave, dim right. + if not r.wave_edges & Edges.Bottom: + dim("bottom") + if not r.wave_edges & Edges.Right: + dim("right") + + else: ax.set_axis_off() - # Generate arrangement (using nplots, cfg.orientation) - self._axes: List[Axes] = self.layout.arrange(lambda row, col: axes2d[row, col]) + return ax + + # Generate arrangement (using self.lcfg, wave_nchans) + # _axes2d[wave][chan] = Axes + self._axes2d = self.layout.arrange(axes_factory) # Setup figure geometry self._fig.set_dpi(DPI) @@ -255,45 +290,66 @@ class MatplotlibRenderer(Renderer): ) # Initialize axes and draw waveform data - if self._lines is None: + if not self._lines2d: + assert len(datas[0].shape) == 2, datas[0].shape + + wave_nchans = [data.shape[1] for data in datas] + self._set_layout(wave_nchans) + cfg = self.cfg # Setup background/axes self._fig.set_facecolor(cfg.bg_color) - for idx, data in enumerate(datas): - ax = self._axes[idx] - max_x = len(data) - 1 - ax.set_xlim(0, max_x) - ax.set_ylim(-1, 1) + for idx, wave_data in enumerate(datas): + wave_axes = self._axes2d[idx] + for ax in unique_by_id(wave_axes): + max_x = len(wave_data) - 1 + ax.set_xlim(0, max_x) + ax.set_ylim(-1, 1) - # Setup midlines - midline_color = cfg.midline_color - midline_width = pixels(1) + # Setup midlines (depends on max_x and wave_data) + midline_color = cfg.midline_color + midline_width = pixels(1) - # zorder=-100 still draws on top of gridlines :( - kw = dict(color=midline_color, linewidth=midline_width) - if cfg.v_midline: - ax.axvline(x=max_x / 2, **kw) - if cfg.h_midline: - ax.axhline(y=0, **kw) + # zorder=-100 still draws on top of gridlines :( + kw = dict(color=midline_color, linewidth=midline_width) + if cfg.v_midline: + ax.axvline(x=max_x / 2, **kw) + if cfg.h_midline: + ax.axhline(y=0, **kw) self._save_background() # Plot lines over background line_width = pixels(cfg.line_width) - self._lines = [] - for idx, data in enumerate(datas): - ax = self._axes[idx] - line_color = self._line_params[idx].color - line = ax.plot(data, color=line_color, linewidth=line_width)[0] - self._lines.append(line) + # Foreach wave + for wave_idx, wave_data in enumerate(datas): + wave_axes = self._axes2d[wave_idx] + wave_lines = [] + + # Foreach chan + for chan_idx, chan_data in enumerate(wave_data.T): + ax = wave_axes[chan_idx] + line_color = self._line_params[wave_idx].color + chan_line: Line2D = ax.plot( + chan_data, color=line_color, linewidth=line_width + )[0] + wave_lines.append(chan_line) + + self._lines2d.append(wave_lines) + self._lines_flat.extend(wave_lines) # Draw waveform data else: - for idx, data in enumerate(datas): - line = self._lines[idx] - line.set_ydata(data) + # Foreach wave + for wave_idx, wave_data in enumerate(datas): + wave_lines = self._lines2d[wave_idx] + + # Foreach chan + for chan_idx, chan_data in enumerate(wave_data.T): + chan_line = wave_lines[chan_idx] + chan_line.set_ydata(chan_data) self._redraw_over_background() @@ -314,8 +370,7 @@ class MatplotlibRenderer(Renderer): canvas: FigureCanvasAgg = self._fig.canvas canvas.restore_region(self.bg_cache) - assert self._lines is not None - for line in self._lines: + for line in self._lines_flat: line.axes.draw_artist(line) # https://bastibe.de/2013-05-30-speeding-up-matplotlib.html diff --git a/corrscope/wave.py b/corrscope/wave.py index 94df619..c500852 100644 --- a/corrscope/wave.py +++ b/corrscope/wave.py @@ -40,7 +40,7 @@ class Wave: __slots__ = """ wave_path amplification - smp_s data _flatten is_mono + smp_s data return_channels _flatten is_mono nsamp dtype center max_val """.split() @@ -91,6 +91,7 @@ class Wave: assert self.data.ndim in [1, 2] self.is_mono = self.data.ndim == 1 self.flatten = flatten + self.return_channels = False # Cast self.data to stereo (nsamp, nchan) if self.is_mono: @@ -130,9 +131,10 @@ class Wave: else: raise CorrError(f"unexpected wavfile dtype {dtype}") - def with_flatten(self, flatten: Flatten) -> "Wave": + def with_flatten(self, flatten: Flatten, return_channels: bool) -> "Wave": new = copy.copy(self) new.flatten = flatten + new.return_channels = return_channels return new def __getitem__(self, index: Union[int, slice]) -> np.ndarray: @@ -154,6 +156,9 @@ class Wave: data -= self.center data *= self.amplification / self.max_val + + if self.return_channels and len(data.shape) == 1: + data = data.reshape(-1, 1) return data def _get(self, begin: int, end: int, subsampling: int) -> np.ndarray: diff --git a/tests/conftest.py b/tests/conftest.py index 07008ca..5616cea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ """ Integration tests found in: - test_cli.py +- test_renderer.py - test_output.py """ diff --git a/tests/test_channel.py b/tests/test_channel.py index e6aa74a..4fa8cc8 100644 --- a/tests/test_channel.py +++ b/tests/test_channel.py @@ -12,6 +12,7 @@ from corrscope.channel import ChannelConfig, Channel from corrscope.corrscope import default_config, CorrScope, BenchmarkMode, Arguments from corrscope.triggers import NullTriggerConfig from corrscope.util import coalesce +from corrscope.wave import Flatten positive = hs.integers(min_value=1, max_value=100) @@ -134,3 +135,34 @@ def test_config_channel_width_stride( # line_color is tested in test_renderer.py + + +@pytest.mark.parametrize("filename", ["tests/sine440.wav", "tests/stereo in-phase.wav"]) +@pytest.mark.parametrize( + ("global_stereo", "chan_stereo"), + [ + [Flatten.SumAvg, None], + [Flatten.Stereo, None], + [Flatten.SumAvg, Flatten.Stereo], + [Flatten.Stereo, Flatten.SumAvg], + ], +) +def test_per_channel_stereo( + filename: str, global_stereo: Flatten, chan_stereo: Optional[Flatten] +): + """Ensure you can enable/disable stereo on a per-channel basis.""" + stereo = coalesce(chan_stereo, global_stereo) + + # Test render wave. + cfg = default_config(render_stereo=global_stereo) + ccfg = ChannelConfig("tests/stereo in-phase.wav", render_stereo=chan_stereo) + channel = Channel(ccfg, cfg) + + # Render wave *must* return stereo. + assert channel.render_wave[0:1].ndim == 2 + data = channel.render_wave.get_around(0, return_nsamp=4, stride=1) + assert data.ndim == 2 + + if "stereo" in filename: + assert channel.render_wave._flatten == stereo + assert data.shape[1] == (2 if stereo is Flatten.Stereo else 1) diff --git a/tests/test_layout.py b/tests/test_layout.py index ee9e060..0a5a05e 100644 --- a/tests/test_layout.py +++ b/tests/test_layout.py @@ -1,9 +1,22 @@ -import numpy as np -import pytest +from typing import List -from corrscope.layout import LayoutConfig, RendererLayout, EdgeFinder +import hypothesis.strategies as hs +import numpy as np +import numpy.testing as npt +import pytest +from hypothesis import given, settings + +from corrscope.layout import ( + LayoutConfig, + RendererLayout, + RegionSpec, + Edges, + Orientation, + StereoOrientation, +) from corrscope.renderer import RendererConfig, MatplotlibRenderer -from tests.test_renderer import WIDTH, HEIGHT +from corrscope.util import ceildiv +from tests.test_renderer import WIDTH, HEIGHT, RENDER_Y_ZEROS def test_layout_config(): @@ -22,44 +35,184 @@ def test_layout_config(): assert default.orientation == "h" +# Small range to ensure many collisions. +# max_value = 3 to allow for edge-free space in center. +integers = hs.integers(-1, 3) + + +@given(nrows=integers, ncols=integers, row=integers, col=integers) +@settings(max_examples=500) +def test_edges(nrows: int, ncols: int, row: int, col: int): + if not (nrows > 0 and ncols > 0 and 0 <= row < nrows and 0 <= col < ncols): + with pytest.raises(ValueError): + edges = Edges.at(nrows, ncols, row, col) + return + + edges = Edges.at(nrows, ncols, row, col) + assert bool(edges & Edges.Left) == (col == 0) + assert bool(edges & Edges.Right) == (col == ncols - 1) + assert bool(edges & Edges.Top) == (row == 0) + assert bool(edges & Edges.Bottom) == (row == nrows - 1) + + @pytest.mark.parametrize("lcfg", [LayoutConfig(ncols=2), LayoutConfig(nrows=8)]) -@pytest.mark.parametrize("region_type", [str, tuple, list]) -def test_hlayout(lcfg, region_type): +def test_hlayout(lcfg): nplots = 15 - layout = RendererLayout(lcfg, nplots) + layout = RendererLayout(lcfg, [1] * nplots) - assert layout.ncols == 2 - assert layout.nrows == 8 + assert layout.wave_ncol == 2 + assert layout.wave_nrow == 8 - regions = layout.arrange(lambda row, col: region_type((row, col))) - assert len(regions) == nplots + region2d: List[List[RegionSpec]] = layout.arrange(lambda arg: arg) + assert len(region2d) == nplots + for i, regions in enumerate(region2d): + assert len(regions) == 1, (i, len(regions)) + + np.testing.assert_equal(region2d[0][0].pos, (0, 0)) + np.testing.assert_equal(region2d[1][0].pos, (0, 1)) + np.testing.assert_equal(region2d[2][0].pos, (1, 0)) - assert regions[0] == region_type((0, 0)) - assert regions[1] == region_type((0, 1)) - assert regions[2] == region_type((1, 0)) m = nplots - 1 - assert regions[m] == region_type((m // 2, m % 2)) + npt.assert_equal(region2d[m][0].pos, (m // 2, m % 2)) @pytest.mark.parametrize( "lcfg", [LayoutConfig(ncols=3, orientation="v"), LayoutConfig(nrows=3, orientation="v")], ) -@pytest.mark.parametrize("region_type", [str, tuple, list]) -def test_vlayout(lcfg, region_type): +def test_vlayout(lcfg): nplots = 7 - layout = RendererLayout(lcfg, nplots) + layout = RendererLayout(lcfg, [1] * nplots) - assert layout.ncols == 3 - assert layout.nrows == 3 + assert layout.wave_ncol == 3 + assert layout.wave_nrow == 3 - regions = layout.arrange(lambda row, col: region_type((row, col))) - assert len(regions) == nplots + region2d: List[List[RegionSpec]] = layout.arrange(lambda arg: arg) + assert len(region2d) == nplots + for i, regions in enumerate(region2d): + assert len(regions) == 1, (i, len(regions)) - assert regions[0] == region_type((0, 0)) - assert regions[2] == region_type((2, 0)) - assert regions[3] == region_type((0, 1)) - assert regions[6] == region_type((0, 2)) + np.testing.assert_equal(region2d[0][0].pos, (0, 0)) + np.testing.assert_equal(region2d[2][0].pos, (2, 0)) + np.testing.assert_equal(region2d[3][0].pos, (0, 1)) + np.testing.assert_equal(region2d[6][0].pos, (0, 2)) + + +@given( + wave_nchans=hs.lists(hs.integers(1, 10), min_size=1, max_size=100), + orientation=hs.sampled_from(Orientation), + stereo_orientation=hs.sampled_from(StereoOrientation), + nrow_ncol=hs.integers(1, 100), + is_nrows=hs.booleans(), +) +def test_stereo_layout( + orientation: Orientation, + stereo_orientation: StereoOrientation, + wave_nchans: List[int], + nrow_ncol: int, + is_nrows: bool, +): + """ + Not-entirely-rigorous test for layout computation. + Mind-numbingly boring to write (and read?). + + Honestly I prefer a good naming scheme in RendererLayout.arrange() + over unit tests. + + - This is a regression test... + - And an obstacle to refactoring or feature development. + """ + # region Setup + if is_nrows: + nrows = nrow_ncol + ncols = None + else: + nrows = None + ncols = nrow_ncol + + lcfg = LayoutConfig( + orientation=orientation, + nrows=nrows, + ncols=ncols, + stereo_orientation=stereo_orientation, + ) + nwaves = len(wave_nchans) + layout = RendererLayout(lcfg, wave_nchans) + # endregion + + # Assert layout dimensions correct + assert layout.wave_ncol == ncols or ceildiv(nwaves, nrows) + assert layout.wave_nrow == nrows or ceildiv(nwaves, ncols) + + region2d: List[List[RegionSpec]] = layout.arrange(lambda r_spec: r_spec) + + # Loop through layout regions + assert len(region2d) == len(wave_nchans) + for wave_i, wave_chans in enumerate(region2d): + stereo_nchan = wave_nchans[wave_i] + assert len(wave_chans) == stereo_nchan + + # Compute channel dims within wave. + if stereo_orientation == StereoOrientation.overlay: + chans_per_wave = [1, 1] + elif stereo_orientation == StereoOrientation.v: # pos[0]++ + chans_per_wave = [stereo_nchan, 1] + else: + assert stereo_orientation == StereoOrientation.h # pos[1]++ + chans_per_wave = [1, stereo_nchan] + + # Sanity-check position of channel 0 relative to origin (wave grid). + assert (np.add.reduce(wave_chans[0].pos) != 0) == (wave_i != 0) + npt.assert_equal(wave_chans[0].pos % chans_per_wave, 0) + + for chan_j, chan in enumerate(wave_chans): + # Assert 0 <= position < size. + assert chan.pos.shape == chan.size.shape == (2,) + assert (0 <= chan.pos).all() + assert (chan.pos < chan.size).all() + + # Sanity-check position of chan relative to origin (wave grid). + npt.assert_equal( + chan.pos // chans_per_wave, wave_chans[0].pos // chans_per_wave + ) + + # Check position of region (relative to channel 0) + chan_wave_pos = chan.pos - wave_chans[0].pos + + if stereo_orientation == StereoOrientation.overlay: + npt.assert_equal(chan_wave_pos, [0, 0]) + elif stereo_orientation == StereoOrientation.v: # pos[0]++ + npt.assert_equal(chan_wave_pos, [chan_j, 0]) + else: + assert stereo_orientation == StereoOrientation.h # pos[1]++ + npt.assert_equal(chan_wave_pos, [0, chan_j]) + + # Check screen edges + screen_edges = chan.screen_edges + assert bool(screen_edges & Edges.Top) == (chan.row == 0) + assert bool(screen_edges & Edges.Left) == (chan.col == 0) + assert bool(screen_edges & Edges.Bottom) == (chan.row == chan.nrow - 1) + assert bool(screen_edges & Edges.Right) == (chan.col == chan.ncol - 1) + + # Check stereo edges + wave_edges = chan.wave_edges + if stereo_orientation == StereoOrientation.overlay: + assert wave_edges == ~Edges.NONE + elif stereo_orientation == StereoOrientation.v: # pos[0]++ + lr = Edges.Left | Edges.Right + assert wave_edges & lr == lr + assert bool(wave_edges & Edges.Top) == (chan.row % stereo_nchan == 0) + assert bool(wave_edges & Edges.Bottom) == ( + (chan.row + 1) % stereo_nchan == 0 + ) + else: + assert stereo_orientation == StereoOrientation.h # pos[1]++ + tb = Edges.Top | Edges.Bottom + assert wave_edges & tb == tb + assert bool(wave_edges & Edges.Left) == (chan.col % stereo_nchan == 0) + assert bool(wave_edges & Edges.Right) == ( + (chan.col + 1) % stereo_nchan == 0 + ) def test_renderer_layout(): @@ -69,21 +222,9 @@ def test_renderer_layout(): nplots = 15 r = MatplotlibRenderer(cfg, lcfg, nplots, None) + r.render_frame([RENDER_Y_ZEROS] * nplots) + layout = r.layout # 2 columns, 8 rows - assert r.layout.ncols == 2 - assert r.layout.nrows == 8 - - -# Test EdgeFinder - - -def test_edge_finder(): - regions2d = np.arange(24).reshape(3, 8) - edges = EdgeFinder(regions2d) - - # Check borders - np.testing.assert_equal(edges.lefts, regions2d[:, 0]) - np.testing.assert_equal(edges.rights, regions2d[:, -1]) - np.testing.assert_equal(edges.tops, regions2d[0, :]) - np.testing.assert_equal(edges.bottoms, regions2d[-1, :]) + assert layout.wave_ncol == 2 + assert layout.wave_nrow == 8 diff --git a/tests/test_output.py b/tests/test_output.py index db174f5..3769dff 100644 --- a/tests/test_output.py +++ b/tests/test_output.py @@ -23,11 +23,8 @@ from corrscope.outputs import ( Stop, ) from corrscope.renderer import RendererConfig, MatplotlibRenderer -from corrscope.settings.paths import MissingFFmpegError -from tests.test_renderer import ALL_ZEROS +from tests.test_renderer import RENDER_Y_ZEROS, WIDTH, HEIGHT -WIDTH = 192 -HEIGHT = 108 if TYPE_CHECKING: import pytest_mock @@ -90,7 +87,7 @@ def test_render_output(): renderer = MatplotlibRenderer(CFG.render, CFG.layout, nplots=1, channel_cfgs=None) out: FFmpegOutput = NULL_FFMPEG_OUTPUT(CFG) - renderer.render_frame([ALL_ZEROS]) + renderer.render_frame([RENDER_Y_ZEROS]) out.write_frame(renderer.get_frame()) assert out.close() == 0 diff --git a/tests/test_renderer.py b/tests/test_renderer.py index cd507cb..733655b 100644 --- a/tests/test_renderer.py +++ b/tests/test_renderer.py @@ -1,38 +1,47 @@ -from typing import Optional +from typing import Optional, TYPE_CHECKING +import matplotlib.colors import numpy as np import pytest -from matplotlib.colors import to_rgb from corrscope.channel import ChannelConfig +from corrscope.corrscope import CorrScope, default_config, Arguments from corrscope.layout import LayoutConfig -from corrscope.outputs import RGB_DEPTH +from corrscope.outputs import RGB_DEPTH, FFplayOutputConfig from corrscope.renderer import RendererConfig, MatplotlibRenderer +from corrscope.wave import Flatten -WIDTH = 640 -HEIGHT = 360 +if TYPE_CHECKING: + import pytest_mock -ALL_ZEROS = np.array([0, 0]) +WIDTH = 64 +HEIGHT = 64 + +RENDER_Y_ZEROS = np.zeros((2, 1)) +RENDER_Y_STEREO = np.zeros((2, 2)) +OPACITY = 2 / 3 all_colors = pytest.mark.parametrize( - "bg_str,fg_str,grid_str", + "bg_str,fg_str,grid_str,data", [ - ("#000000", "#ffffff", None), - ("#ffffff", "#000000", None), - ("#0000aa", "#aaaa00", None), - ("#aaaa00", "#0000aa", None), - # Enabling gridlines enables Axes rectangles. - # Make sure they don't draw *over* the global figure background. - ("#0000aa", "#aaaa00", "#ff00ff"), # beautiful magenta gridlines - ("#aaaa00", "#0000aa", "#ff00ff"), + ("#000000", "#ffffff", None, RENDER_Y_ZEROS), + ("#ffffff", "#000000", None, RENDER_Y_ZEROS), + ("#0000aa", "#aaaa00", None, RENDER_Y_ZEROS), + ("#aaaa00", "#0000aa", None, RENDER_Y_ZEROS), + # Enabling ~~beautiful magenta~~ gridlines enables Axes rectangles. + # Make sure bg is disabled, so they don't overwrite global figure background. + ("#0000aa", "#aaaa00", "#ff00ff", RENDER_Y_ZEROS), + ("#aaaa00", "#0000aa", "#ff00ff", RENDER_Y_ZEROS), + ("#0000aa", "#aaaa00", "#ff00ff", RENDER_Y_STEREO), + ("#aaaa00", "#0000aa", "#ff00ff", RENDER_Y_STEREO), ], ) -nplots = 2 +NPLOTS = 2 @all_colors -def test_default_colors(bg_str, fg_str, grid_str): +def test_default_colors(bg_str, fg_str, grid_str, data): """ Test the default background/foreground colors. """ cfg = RendererConfig( WIDTH, @@ -40,23 +49,24 @@ def test_default_colors(bg_str, fg_str, grid_str): bg_color=bg_str, init_line_color=fg_str, grid_color=grid_str, + stereo_grid_opacity=OPACITY, line_width=2.0, antialiasing=False, ) lcfg = LayoutConfig() - r = MatplotlibRenderer(cfg, lcfg, nplots, None) - verify(r, bg_str, fg_str, grid_str) + r = MatplotlibRenderer(cfg, lcfg, NPLOTS, None) + verify(r, bg_str, fg_str, grid_str, data) # Ensure default ChannelConfig(line_color=None) does not override line color chan = ChannelConfig(wav_path="") - channels = [chan] * nplots - r = MatplotlibRenderer(cfg, lcfg, nplots, channels) - verify(r, bg_str, fg_str, grid_str) + channels = [chan] * NPLOTS + r = MatplotlibRenderer(cfg, lcfg, NPLOTS, channels) + verify(r, bg_str, fg_str, grid_str, data) @all_colors -def test_line_colors(bg_str, fg_str, grid_str): +def test_line_colors(bg_str, fg_str, grid_str, data): """ Test channel-specific line color overrides """ cfg = RendererConfig( WIDTH, @@ -64,30 +74,44 @@ def test_line_colors(bg_str, fg_str, grid_str): bg_color=bg_str, init_line_color="#888888", grid_color=grid_str, + stereo_grid_opacity=OPACITY, line_width=2.0, antialiasing=False, ) lcfg = LayoutConfig() chan = ChannelConfig(wav_path="", line_color=fg_str) - channels = [chan] * nplots - r = MatplotlibRenderer(cfg, lcfg, nplots, channels) - verify(r, bg_str, fg_str, grid_str) + channels = [chan] * NPLOTS + r = MatplotlibRenderer(cfg, lcfg, NPLOTS, channels) + verify(r, bg_str, fg_str, grid_str, data) -def verify(r: MatplotlibRenderer, bg_str, fg_str, grid_str: Optional[str]): - r.render_frame([ALL_ZEROS] * nplots) +TOLERANCE = 3 + + +def verify( + r: MatplotlibRenderer, bg_str, fg_str, grid_str: Optional[str], data: np.ndarray +): + r.render_frame([data] * NPLOTS) frame_colors: np.ndarray = np.frombuffer(r.get_frame(), dtype=np.uint8).reshape( (-1, RGB_DEPTH) ) - bg_u8 = [round(c * 255) for c in to_rgb(bg_str)] - fg_u8 = [round(c * 255) for c in to_rgb(fg_str)] + bg_u8 = to_rgb(bg_str) + fg_u8 = to_rgb(fg_str) all_colors = [bg_u8, fg_u8] if grid_str: - grid_u8 = [round(c * 255) for c in to_rgb(grid_str)] + grid_u8 = to_rgb(grid_str) all_colors.append(grid_u8) + else: + grid_u8 = bg_u8 + + assert (data.shape[1] > 1) == (data is RENDER_Y_STEREO) + is_stereo = data.shape[1] > 1 + if is_stereo: + stereo_grid_u8 = (grid_u8 * OPACITY + bg_u8 * (1 - OPACITY)).astype(int) + all_colors.append(stereo_grid_u8) # Ensure background is correct bg_frame = frame_colors[0] @@ -104,5 +128,36 @@ def verify(r: MatplotlibRenderer, bg_str, fg_str, grid_str: Optional[str]): if grid_str: assert np.prod(frame_colors == grid_u8, axis=-1).any(), "Missing grid_str" + # Ensure stereo grid color is present + if is_stereo: + assert ( + np.min(np.sum(np.abs(frame_colors - stereo_grid_u8), axis=-1)) < TOLERANCE + ), "Missing stereo gridlines" + assert (np.amax(frame_colors, axis=0) == np.amax(all_colors, axis=0)).all() assert (np.amin(frame_colors, axis=0) == np.amin(all_colors, axis=0)).all() + + +def to_rgb(c) -> np.ndarray: + to_rgb = matplotlib.colors.to_rgb + return np.array([round(c * 255) for c in to_rgb(c)], dtype=int) + + +# Stereo *renderer* integration tests. +def test_stereo_render_integration(mocker: "pytest_mock.MockFixture"): + """Ensure corrscope plays/renders in stereo, without crashing.""" + + # Stub out FFplay output. + mocker.patch.object(FFplayOutputConfig, "cls") + + # Render in stereo. + cfg = default_config( + channels=[ChannelConfig("tests/stereo in-phase.wav")], + render_stereo=Flatten.Stereo, + end_time=0.5, # Reduce test duration + render=RendererConfig(WIDTH, HEIGHT), + ) + + # Make sure it doesn't crash. + corr = CorrScope(cfg, Arguments(".", [FFplayOutputConfig()])) + corr.play() diff --git a/tests/test_wave.py b/tests/test_wave.py index b512add..2680c66 100644 --- a/tests/test_wave.py +++ b/tests/test_wave.py @@ -79,6 +79,7 @@ AllFlattens = Flatten.__members__.values() @pytest.mark.parametrize("flatten", AllFlattens) +@pytest.mark.parametrize("return_channels", [False, True]) @pytest.mark.parametrize( "path,nchan,peaks", [ @@ -88,19 +89,30 @@ AllFlattens = Flatten.__members__.values() ], ) def test_stereo_flatten_modes( - flatten: Flatten, path: str, nchan: int, peaks: Sequence[float] + flatten: Flatten, + return_channels: bool, + path: str, + nchan: int, + peaks: Sequence[float], ): """Ensures all Flatten modes are handled properly for stereo and mono signals.""" + + # return_channels=False <-> triggering. + # flatten=stereo -> rendering. + # These conditions do not currently coexist. + # if not return_channels and flatten == Flatten.Stereo: + # return + assert nchan == len(peaks) wave = Wave(path) if flatten not in Flatten.modes: with pytest.raises(CorrError): - wave.with_flatten(flatten) + wave.with_flatten(flatten, return_channels) return else: - wave = wave.with_flatten(flatten) + wave = wave.with_flatten(flatten, return_channels) nsamp = wave.nsamp data = wave[:] @@ -111,7 +123,10 @@ def test_stereo_flatten_modes( for chan_data, peak in zip(data.T, peaks): assert_full_scale(chan_data, peak) else: - assert data.shape == (nsamp,) + if return_channels: + assert data.shape == (nsamp, 1) + else: + assert data.shape == (nsamp,) # If DiffAvg and in-phase, L-R=0. if flatten == Flatten.DiffAvg: