Add backend for stereo rendering (#196)

GUI support will be added in a separate PR.
pull/357/head
nyanpasu64 2019-02-18 02:10:17 -08:00 zatwierdzone przez GitHub
rodzic d1bdf0cb34
commit 0dbba8ac09
10 zmienionych plików z 652 dodań i 195 usunięć

Wyświetl plik

@ -62,8 +62,8 @@ class Channel:
tflat = coalesce(cfg.trigger_stereo, corr_cfg.trigger_stereo)
rflat = coalesce(cfg.render_stereo, corr_cfg.render_stereo)
self.trigger_wave = wave.with_flatten(tflat)
self.render_wave = wave.with_flatten(rflat)
self.trigger_wave = wave.with_flatten(tflat, return_channels=False)
self.render_wave = wave.with_flatten(rflat, return_channels=True)
# `subsampling` increases `stride` and decreases `nsamp`.
# `width` increases `stride` without changing `nsamp`.

Wyświetl plik

@ -1,16 +1,41 @@
from typing import Optional, TypeVar, Callable, List, Generic, Tuple
import collections
import enum
from enum import auto
from typing import Optional, TypeVar, Callable, List, Iterable
import attr
import numpy as np
from corrscope.config import DumpableAttrs, CorrError
from corrscope.config import DumpableAttrs, CorrError, DumpEnumAsStr
from corrscope.util import ceildiv
class LayoutConfig(DumpableAttrs, always_dump="orientation"):
orientation: str = "h"
class Orientation(str, DumpEnumAsStr):
h = "h"
v = "v"
class StereoOrientation(str, DumpEnumAsStr):
h = "h"
v = "v"
overlay = "overlay"
assert Orientation.h == StereoOrientation.h
H = Orientation.h
V = Orientation.v
OVERLAY = StereoOrientation.overlay
class LayoutConfig(DumpableAttrs, always_dump="orientation stereo_orientation"):
orientation: Orientation = attr.ib(default="h", converter=Orientation)
nrows: Optional[int] = None
ncols: Optional[int] = None
stereo_orientation: StereoOrientation = attr.ib(
default="h", converter=StereoOrientation
)
def __attrs_post_init__(self) -> None:
if not self.nrows:
self.nrows = None
@ -24,31 +49,90 @@ class LayoutConfig(DumpableAttrs, always_dump="orientation"):
self.ncols = 1
class Edges(enum.Flag):
NONE = 0
Top = auto()
Left = auto()
Bottom = auto()
Right = auto()
@staticmethod
def at(nrows: int, ncols: int, row: int, col: int):
if not nrows > 0:
raise ValueError(f"invalid nrows={nrows}, must be positive")
if not ncols > 0:
raise ValueError(f"invalid ncols={ncols}, must be positive")
if not 0 <= row < nrows:
raise ValueError(f"invalid row={row} not in [0 .. nrows={nrows})")
if not 0 <= col < ncols:
raise ValueError(f"invalid col={col} not in [0 .. ncols={ncols})")
ret = Edges.NONE
if row == 0:
ret |= Edges.Top
if row + 1 == nrows:
ret |= Edges.Bottom
if col == 0:
ret |= Edges.Left
if col + 1 == ncols:
ret |= Edges.Right
return ret
def attr_idx_property(key: str, idx: int) -> property:
@property
def inner(self: "RegionSpec"):
return getattr(self, key)[idx]
return inner
@attr.dataclass
class RegionSpec:
"""
- Origin is located at top-left.
- Row 0 = top.
- Row nrows-1 = bottom.
- Col 0 = left.
- Col ncols-1 = right.
"""
size: np.ndarray
pos: np.ndarray
nrow = attr_idx_property("size", 0)
ncol = attr_idx_property("size", 1)
row = attr_idx_property("pos", 0)
col = attr_idx_property("pos", 1)
screen_edges: "Edges"
wave_edges: "Edges"
Region = TypeVar("Region")
RegionFactory = Callable[[int, int], Region] # f(row, column) -> Region
RegionFactory = Callable[[RegionSpec], Region] # f(row, column) -> Region
class RendererLayout:
VALID_ORIENTATIONS = ["h", "v"]
def __init__(self, cfg: LayoutConfig, nplots: int):
def __init__(self, cfg: LayoutConfig, wave_nchans: List[int]):
self.cfg = cfg
self.nplots = nplots
# Setup layout
self.nrows, self.ncols = self._calc_layout()
self.nwaves = len(wave_nchans)
self.wave_nchans = wave_nchans
self.orientation = cfg.orientation
if self.orientation not in self.VALID_ORIENTATIONS:
raise CorrError(
f"Invalid orientation {self.orientation} not in "
f"{self.VALID_ORIENTATIONS}"
)
self.stereo_orientation = cfg.stereo_orientation
def _calc_layout(self) -> Tuple[int, int]:
# Setup layout
self._calc_layout()
# Shape of wave slots
wave_nrow: int
wave_ncol: int
def _calc_layout(self) -> None:
"""
Inputs: self.cfg, self.waves
:return: (nrows, ncols)
Inputs: self.cfg, self.stereo_nchan
Outputs: self.wave_nrow, ncol
"""
cfg = self.cfg
@ -56,7 +140,7 @@ class RendererLayout:
nrows = cfg.nrows
if nrows is None:
raise ValueError("impossible cfg: nrows is None and true")
ncols = ceildiv(self.nplots, nrows)
ncols = ceildiv(self.nwaves, nrows)
else:
if cfg.ncols is None:
raise ValueError(
@ -64,38 +148,110 @@ class RendererLayout:
"(__attrs_post_init__ not called?)"
)
ncols = cfg.ncols
nrows = ceildiv(self.nplots, ncols)
nrows = ceildiv(self.nwaves, ncols)
return nrows, ncols
self.wave_nrow = nrows
self.wave_ncol = ncols
def arrange(self, region_factory: RegionFactory[Region]) -> List[Region]:
""" Generates an array of regions.
index, row, column are fed into region_factory in a row-major order [row][col].
The results are possibly reshaped into column-major order [col][row].
def arrange(self, region_factory: RegionFactory[Region]) -> List[List[Region]]:
"""
nspaces = self.nrows * self.ncols
inds = np.arange(nspaces)
rows, cols = np.unravel_index(inds, (self.nrows, self.ncols))
(row, column) are fed into region_factory in a row-major order [row][col].
Stereo channel pairs are extracted.
The results are possibly reshaped into column-major order [col][row].
row_col = list(zip(rows, cols))
regions = np.empty(len(row_col), dtype=object) # type: np.ndarray[Region]
regions[:] = [region_factory(*rc) for rc in row_col]
:return arr[wave][channel] = Region
"""
regions2d = regions.reshape(
(self.nrows, self.ncols)
) # type: np.ndarray[Region]
wave_spaces = self.wave_nrow * self.wave_ncol
inds = np.arange(wave_spaces)
# if column major:
if self.orientation == "v":
regions2d = regions2d.T
# Compute location of each wave.
if self.orientation == V:
# column major
cols, rows = np.unravel_index(inds, (self.wave_ncol, self.wave_nrow))
else:
# row major
rows, cols = np.unravel_index(inds, (self.wave_nrow, self.wave_ncol))
return regions2d.flatten()[: self.nplots].tolist()
# Generate plot for each wave.chan. Leave unused slots empty.
region_wave_chan: List[List[Region]] = []
# The order of (rows, cols) has no effect.
for stereo_nchan, wave_row, wave_col in zip(self.wave_nchans, rows, cols):
# Wave = within Screen.
# Chan = within Wave, generate a plot.
# All arrays are [y, x] == [row, col].
# Wave dim/pos (within screen)
waves_per_screen = arr(self.wave_nrow, self.wave_ncol)
wave_screen_pos = arr(wave_row, wave_col)
del wave_row, wave_col
# Channel dim/pos (within wave)
chans_per_wave = arr(1, 1) # Mutated based on orientation
chan_wave_pos = arr(0, 0) # Mutated in for-chan loop.
# Distance between chans
dchan = arr(0, 0) # Mutated based on orientation
if self.stereo_orientation == V:
chans_per_wave[0] = stereo_nchan
dchan[0] = 1
elif self.stereo_orientation == H:
chans_per_wave[1] = stereo_nchan
dchan[1] = 1
else:
assert self.stereo_orientation == OVERLAY
# Channel dim/pos (within screen)
chans_per_screen = chans_per_wave * waves_per_screen
chan_screen_pos = chans_per_wave * wave_screen_pos
# Generate plots for each channel
region_chan: List[Region] = []
region_wave_chan.append(region_chan)
region = None
for chan in range(stereo_nchan):
assert (chan_wave_pos < chans_per_wave).all()
assert (chan_screen_pos < chans_per_screen).all()
# Generate plot (channel position in screen)
if region is None or dchan.any():
screen_edges = Edges.at(*chans_per_screen, *chan_screen_pos)
wave_edges = Edges.at(*chans_per_wave, *chan_wave_pos)
# Removing .copy() causes bugs if region_factory() holds
# mutable references.
chan_spec = RegionSpec(
chans_per_screen.copy(),
chan_screen_pos.copy(),
screen_edges,
wave_edges,
)
region = region_factory(chan_spec)
region_chan.append(region)
# Move to next channel position
chan_screen_pos += dchan
chan_wave_pos += dchan
assert len(region_wave_chan) == self.nwaves
return region_wave_chan
class EdgeFinder(Generic[Region]):
def __init__(self, regions2d: np.ndarray):
self.tops: List[Region] = regions2d[0, :].tolist()
self.bottoms: List[Region] = regions2d[-1, :].tolist()
self.lefts: List[Region] = regions2d[:, 0].tolist()
self.rights: List[Region] = regions2d[:, -1].tolist()
def arr(*args):
return np.array(args)
T = TypeVar("T")
def unique_by_id(items: Iterable[T]) -> List[T]:
seen = collections.OrderedDict()
for item in items:
if id(item) not in seen:
seen[id(item)] = item
return list(seen.values())

Wyświetl plik

@ -4,10 +4,17 @@ from typing import Optional, List, TYPE_CHECKING, Any
import attr
import matplotlib
import matplotlib.colors
import numpy as np
from corrscope.config import DumpableAttrs, with_units
from corrscope.layout import RendererLayout, LayoutConfig, EdgeFinder
from corrscope.layout import (
RendererLayout,
LayoutConfig,
unique_by_id,
RegionSpec,
Edges,
)
from corrscope.outputs import RGB_DEPTH, ByteBuffer
from corrscope.util import coalesce
@ -58,7 +65,10 @@ class RendererConfig(DumpableAttrs, always_dump="*"):
bg_color: str = "#000000"
init_line_color: str = default_color()
grid_color: Optional[str] = None
stereo_grid_opacity: float = 0.5
midline_color: Optional[str] = None
v_midline: bool = False
h_midline: bool = False
@ -99,8 +109,8 @@ class Renderer(ABC):
channel_cfgs: Optional[List["ChannelConfig"]],
):
self.cfg = cfg
self.lcfg = lcfg
self.nplots = nplots
self.layout = RendererLayout(lcfg, nplots)
# Load line colors.
if channel_cfgs is not None:
@ -165,16 +175,20 @@ class MatplotlibRenderer(Renderer):
matplotlib.rcParams, "lines.antialiased", self.cfg.antialiasing
)
# Flat array of nrows*ncols elements, ordered by cfg.rows_first.
self._fig: "Figure"
self._axes: List["Axes"] # set by set_layout()
self._lines: Optional[List["Line2D"]] = None # set by render_frame() first call
self._set_layout() # mutates self
# _axes2d[wave][chan] = Axes
self._axes2d: List[List["Axes"]] # set by set_layout()
# _lines2d[wave][chan] = Line2D
self._lines2d: List[List[Line2D]] = []
self._lines_flat: List["Line2D"] = []
transparent = "#00000000"
def _set_layout(self) -> None:
layout: RendererLayout
def _set_layout(self, wave_nchans: List[int]) -> None:
"""
Creates a flat array of Matplotlib Axes, with the new layout.
Opens a window showing the Figure (and Axes).
@ -183,6 +197,8 @@ class MatplotlibRenderer(Renderer):
Outputs: self.nrows, self.ncols, self.axes
"""
self.layout = RendererLayout(self.lcfg, wave_nchans)
# Create Axes
# https://matplotlib.org/api/_as_gen/matplotlib.pyplot.subplots.html
if hasattr(self, "_fig"):
@ -190,24 +206,24 @@ class MatplotlibRenderer(Renderer):
# plt.close(self.fig)
grid_color = self.cfg.grid_color
axes2d: np.ndarray["Axes"]
self._fig = Figure()
FigureCanvasAgg(self._fig)
axes2d = self._fig.subplots(
self.layout.nrows,
self.layout.ncols,
squeeze=False,
# Remove axis ticks (which slow down rendering)
subplot_kw=dict(xticks=[], yticks=[]),
# Remove gaps between Axes TODO borders shouldn't be half-visible
gridspec_kw=dict(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0),
)
# RegionFactory
def axes_factory(r: RegionSpec) -> "Axes":
width = 1 / r.ncol
left = r.col / r.ncol
assert 0 <= left < 1
ax: "Axes"
if grid_color:
# Initialize borders
for ax in axes2d.flatten():
height = 1 / r.nrow
bottom = (r.nrow - r.row - 1) / r.nrow
assert 0 <= bottom < 1
# Disabling xticks/yticks is unnecessary, since we hide Axises.
ax = self._fig.add_axes([left, bottom, width, height], xticks=[], yticks=[])
if grid_color:
# Initialize borders
# Hide Axises
# (drawing them is very slow, and we disable ticks+labels anyway)
ax.get_xaxis().set_visible(False)
@ -223,25 +239,44 @@ class MatplotlibRenderer(Renderer):
for spine in ax.spines.values():
spine.set_color(grid_color)
# gridspec_kw indexes from bottom-left corner.
# Only show bottom-left borders (x=0, y=0)
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
def hide(key: str):
ax.spines[key].set_visible(False)
# Hide bottom-left edges for speed.
edge_axes: EdgeFinder["Axes"] = EdgeFinder(axes2d)
for ax in edge_axes.bottoms:
ax.spines["bottom"].set_visible(False)
for ax in edge_axes.lefts:
ax.spines["left"].set_visible(False)
# Hide all axes except bottom-right.
hide("top")
hide("left")
else:
# Remove Axis from Axes
for ax in axes2d.flatten():
# If bottom of screen, hide bottom. If right of screen, hide right.
if r.screen_edges & Edges.Bottom:
hide("bottom")
if r.screen_edges & Edges.Right:
hide("right")
# Dim stereo gridlines
if self.cfg.stereo_grid_opacity > 0:
dim_color = matplotlib.colors.to_rgba_array(grid_color)[0]
dim_color[-1] = self.cfg.stereo_grid_opacity
def dim(key: str):
ax.spines[key].set_color(dim_color)
else:
dim = hide
# If not bottom of wave, dim bottom. If not right of wave, dim right.
if not r.wave_edges & Edges.Bottom:
dim("bottom")
if not r.wave_edges & Edges.Right:
dim("right")
else:
ax.set_axis_off()
# Generate arrangement (using nplots, cfg.orientation)
self._axes: List[Axes] = self.layout.arrange(lambda row, col: axes2d[row, col])
return ax
# Generate arrangement (using self.lcfg, wave_nchans)
# _axes2d[wave][chan] = Axes
self._axes2d = self.layout.arrange(axes_factory)
# Setup figure geometry
self._fig.set_dpi(DPI)
@ -255,45 +290,66 @@ class MatplotlibRenderer(Renderer):
)
# Initialize axes and draw waveform data
if self._lines is None:
if not self._lines2d:
assert len(datas[0].shape) == 2, datas[0].shape
wave_nchans = [data.shape[1] for data in datas]
self._set_layout(wave_nchans)
cfg = self.cfg
# Setup background/axes
self._fig.set_facecolor(cfg.bg_color)
for idx, data in enumerate(datas):
ax = self._axes[idx]
max_x = len(data) - 1
ax.set_xlim(0, max_x)
ax.set_ylim(-1, 1)
for idx, wave_data in enumerate(datas):
wave_axes = self._axes2d[idx]
for ax in unique_by_id(wave_axes):
max_x = len(wave_data) - 1
ax.set_xlim(0, max_x)
ax.set_ylim(-1, 1)
# Setup midlines
midline_color = cfg.midline_color
midline_width = pixels(1)
# Setup midlines (depends on max_x and wave_data)
midline_color = cfg.midline_color
midline_width = pixels(1)
# zorder=-100 still draws on top of gridlines :(
kw = dict(color=midline_color, linewidth=midline_width)
if cfg.v_midline:
ax.axvline(x=max_x / 2, **kw)
if cfg.h_midline:
ax.axhline(y=0, **kw)
# zorder=-100 still draws on top of gridlines :(
kw = dict(color=midline_color, linewidth=midline_width)
if cfg.v_midline:
ax.axvline(x=max_x / 2, **kw)
if cfg.h_midline:
ax.axhline(y=0, **kw)
self._save_background()
# Plot lines over background
line_width = pixels(cfg.line_width)
self._lines = []
for idx, data in enumerate(datas):
ax = self._axes[idx]
line_color = self._line_params[idx].color
line = ax.plot(data, color=line_color, linewidth=line_width)[0]
self._lines.append(line)
# Foreach wave
for wave_idx, wave_data in enumerate(datas):
wave_axes = self._axes2d[wave_idx]
wave_lines = []
# Foreach chan
for chan_idx, chan_data in enumerate(wave_data.T):
ax = wave_axes[chan_idx]
line_color = self._line_params[wave_idx].color
chan_line: Line2D = ax.plot(
chan_data, color=line_color, linewidth=line_width
)[0]
wave_lines.append(chan_line)
self._lines2d.append(wave_lines)
self._lines_flat.extend(wave_lines)
# Draw waveform data
else:
for idx, data in enumerate(datas):
line = self._lines[idx]
line.set_ydata(data)
# Foreach wave
for wave_idx, wave_data in enumerate(datas):
wave_lines = self._lines2d[wave_idx]
# Foreach chan
for chan_idx, chan_data in enumerate(wave_data.T):
chan_line = wave_lines[chan_idx]
chan_line.set_ydata(chan_data)
self._redraw_over_background()
@ -314,8 +370,7 @@ class MatplotlibRenderer(Renderer):
canvas: FigureCanvasAgg = self._fig.canvas
canvas.restore_region(self.bg_cache)
assert self._lines is not None
for line in self._lines:
for line in self._lines_flat:
line.axes.draw_artist(line)
# https://bastibe.de/2013-05-30-speeding-up-matplotlib.html

Wyświetl plik

@ -40,7 +40,7 @@ class Wave:
__slots__ = """
wave_path
amplification
smp_s data _flatten is_mono
smp_s data return_channels _flatten is_mono
nsamp dtype
center max_val
""".split()
@ -91,6 +91,7 @@ class Wave:
assert self.data.ndim in [1, 2]
self.is_mono = self.data.ndim == 1
self.flatten = flatten
self.return_channels = False
# Cast self.data to stereo (nsamp, nchan)
if self.is_mono:
@ -130,9 +131,10 @@ class Wave:
else:
raise CorrError(f"unexpected wavfile dtype {dtype}")
def with_flatten(self, flatten: Flatten) -> "Wave":
def with_flatten(self, flatten: Flatten, return_channels: bool) -> "Wave":
new = copy.copy(self)
new.flatten = flatten
new.return_channels = return_channels
return new
def __getitem__(self, index: Union[int, slice]) -> np.ndarray:
@ -154,6 +156,9 @@ class Wave:
data -= self.center
data *= self.amplification / self.max_val
if self.return_channels and len(data.shape) == 1:
data = data.reshape(-1, 1)
return data
def _get(self, begin: int, end: int, subsampling: int) -> np.ndarray:

Wyświetl plik

@ -1,6 +1,7 @@
"""
Integration tests found in:
- test_cli.py
- test_renderer.py
- test_output.py
"""

Wyświetl plik

@ -12,6 +12,7 @@ from corrscope.channel import ChannelConfig, Channel
from corrscope.corrscope import default_config, CorrScope, BenchmarkMode, Arguments
from corrscope.triggers import NullTriggerConfig
from corrscope.util import coalesce
from corrscope.wave import Flatten
positive = hs.integers(min_value=1, max_value=100)
@ -134,3 +135,34 @@ def test_config_channel_width_stride(
# line_color is tested in test_renderer.py
@pytest.mark.parametrize("filename", ["tests/sine440.wav", "tests/stereo in-phase.wav"])
@pytest.mark.parametrize(
("global_stereo", "chan_stereo"),
[
[Flatten.SumAvg, None],
[Flatten.Stereo, None],
[Flatten.SumAvg, Flatten.Stereo],
[Flatten.Stereo, Flatten.SumAvg],
],
)
def test_per_channel_stereo(
filename: str, global_stereo: Flatten, chan_stereo: Optional[Flatten]
):
"""Ensure you can enable/disable stereo on a per-channel basis."""
stereo = coalesce(chan_stereo, global_stereo)
# Test render wave.
cfg = default_config(render_stereo=global_stereo)
ccfg = ChannelConfig("tests/stereo in-phase.wav", render_stereo=chan_stereo)
channel = Channel(ccfg, cfg)
# Render wave *must* return stereo.
assert channel.render_wave[0:1].ndim == 2
data = channel.render_wave.get_around(0, return_nsamp=4, stride=1)
assert data.ndim == 2
if "stereo" in filename:
assert channel.render_wave._flatten == stereo
assert data.shape[1] == (2 if stereo is Flatten.Stereo else 1)

Wyświetl plik

@ -1,9 +1,22 @@
import numpy as np
import pytest
from typing import List
from corrscope.layout import LayoutConfig, RendererLayout, EdgeFinder
import hypothesis.strategies as hs
import numpy as np
import numpy.testing as npt
import pytest
from hypothesis import given, settings
from corrscope.layout import (
LayoutConfig,
RendererLayout,
RegionSpec,
Edges,
Orientation,
StereoOrientation,
)
from corrscope.renderer import RendererConfig, MatplotlibRenderer
from tests.test_renderer import WIDTH, HEIGHT
from corrscope.util import ceildiv
from tests.test_renderer import WIDTH, HEIGHT, RENDER_Y_ZEROS
def test_layout_config():
@ -22,44 +35,184 @@ def test_layout_config():
assert default.orientation == "h"
# Small range to ensure many collisions.
# max_value = 3 to allow for edge-free space in center.
integers = hs.integers(-1, 3)
@given(nrows=integers, ncols=integers, row=integers, col=integers)
@settings(max_examples=500)
def test_edges(nrows: int, ncols: int, row: int, col: int):
if not (nrows > 0 and ncols > 0 and 0 <= row < nrows and 0 <= col < ncols):
with pytest.raises(ValueError):
edges = Edges.at(nrows, ncols, row, col)
return
edges = Edges.at(nrows, ncols, row, col)
assert bool(edges & Edges.Left) == (col == 0)
assert bool(edges & Edges.Right) == (col == ncols - 1)
assert bool(edges & Edges.Top) == (row == 0)
assert bool(edges & Edges.Bottom) == (row == nrows - 1)
@pytest.mark.parametrize("lcfg", [LayoutConfig(ncols=2), LayoutConfig(nrows=8)])
@pytest.mark.parametrize("region_type", [str, tuple, list])
def test_hlayout(lcfg, region_type):
def test_hlayout(lcfg):
nplots = 15
layout = RendererLayout(lcfg, nplots)
layout = RendererLayout(lcfg, [1] * nplots)
assert layout.ncols == 2
assert layout.nrows == 8
assert layout.wave_ncol == 2
assert layout.wave_nrow == 8
regions = layout.arrange(lambda row, col: region_type((row, col)))
assert len(regions) == nplots
region2d: List[List[RegionSpec]] = layout.arrange(lambda arg: arg)
assert len(region2d) == nplots
for i, regions in enumerate(region2d):
assert len(regions) == 1, (i, len(regions))
np.testing.assert_equal(region2d[0][0].pos, (0, 0))
np.testing.assert_equal(region2d[1][0].pos, (0, 1))
np.testing.assert_equal(region2d[2][0].pos, (1, 0))
assert regions[0] == region_type((0, 0))
assert regions[1] == region_type((0, 1))
assert regions[2] == region_type((1, 0))
m = nplots - 1
assert regions[m] == region_type((m // 2, m % 2))
npt.assert_equal(region2d[m][0].pos, (m // 2, m % 2))
@pytest.mark.parametrize(
"lcfg",
[LayoutConfig(ncols=3, orientation="v"), LayoutConfig(nrows=3, orientation="v")],
)
@pytest.mark.parametrize("region_type", [str, tuple, list])
def test_vlayout(lcfg, region_type):
def test_vlayout(lcfg):
nplots = 7
layout = RendererLayout(lcfg, nplots)
layout = RendererLayout(lcfg, [1] * nplots)
assert layout.ncols == 3
assert layout.nrows == 3
assert layout.wave_ncol == 3
assert layout.wave_nrow == 3
regions = layout.arrange(lambda row, col: region_type((row, col)))
assert len(regions) == nplots
region2d: List[List[RegionSpec]] = layout.arrange(lambda arg: arg)
assert len(region2d) == nplots
for i, regions in enumerate(region2d):
assert len(regions) == 1, (i, len(regions))
assert regions[0] == region_type((0, 0))
assert regions[2] == region_type((2, 0))
assert regions[3] == region_type((0, 1))
assert regions[6] == region_type((0, 2))
np.testing.assert_equal(region2d[0][0].pos, (0, 0))
np.testing.assert_equal(region2d[2][0].pos, (2, 0))
np.testing.assert_equal(region2d[3][0].pos, (0, 1))
np.testing.assert_equal(region2d[6][0].pos, (0, 2))
@given(
wave_nchans=hs.lists(hs.integers(1, 10), min_size=1, max_size=100),
orientation=hs.sampled_from(Orientation),
stereo_orientation=hs.sampled_from(StereoOrientation),
nrow_ncol=hs.integers(1, 100),
is_nrows=hs.booleans(),
)
def test_stereo_layout(
orientation: Orientation,
stereo_orientation: StereoOrientation,
wave_nchans: List[int],
nrow_ncol: int,
is_nrows: bool,
):
"""
Not-entirely-rigorous test for layout computation.
Mind-numbingly boring to write (and read?).
Honestly I prefer a good naming scheme in RendererLayout.arrange()
over unit tests.
- This is a regression test...
- And an obstacle to refactoring or feature development.
"""
# region Setup
if is_nrows:
nrows = nrow_ncol
ncols = None
else:
nrows = None
ncols = nrow_ncol
lcfg = LayoutConfig(
orientation=orientation,
nrows=nrows,
ncols=ncols,
stereo_orientation=stereo_orientation,
)
nwaves = len(wave_nchans)
layout = RendererLayout(lcfg, wave_nchans)
# endregion
# Assert layout dimensions correct
assert layout.wave_ncol == ncols or ceildiv(nwaves, nrows)
assert layout.wave_nrow == nrows or ceildiv(nwaves, ncols)
region2d: List[List[RegionSpec]] = layout.arrange(lambda r_spec: r_spec)
# Loop through layout regions
assert len(region2d) == len(wave_nchans)
for wave_i, wave_chans in enumerate(region2d):
stereo_nchan = wave_nchans[wave_i]
assert len(wave_chans) == stereo_nchan
# Compute channel dims within wave.
if stereo_orientation == StereoOrientation.overlay:
chans_per_wave = [1, 1]
elif stereo_orientation == StereoOrientation.v: # pos[0]++
chans_per_wave = [stereo_nchan, 1]
else:
assert stereo_orientation == StereoOrientation.h # pos[1]++
chans_per_wave = [1, stereo_nchan]
# Sanity-check position of channel 0 relative to origin (wave grid).
assert (np.add.reduce(wave_chans[0].pos) != 0) == (wave_i != 0)
npt.assert_equal(wave_chans[0].pos % chans_per_wave, 0)
for chan_j, chan in enumerate(wave_chans):
# Assert 0 <= position < size.
assert chan.pos.shape == chan.size.shape == (2,)
assert (0 <= chan.pos).all()
assert (chan.pos < chan.size).all()
# Sanity-check position of chan relative to origin (wave grid).
npt.assert_equal(
chan.pos // chans_per_wave, wave_chans[0].pos // chans_per_wave
)
# Check position of region (relative to channel 0)
chan_wave_pos = chan.pos - wave_chans[0].pos
if stereo_orientation == StereoOrientation.overlay:
npt.assert_equal(chan_wave_pos, [0, 0])
elif stereo_orientation == StereoOrientation.v: # pos[0]++
npt.assert_equal(chan_wave_pos, [chan_j, 0])
else:
assert stereo_orientation == StereoOrientation.h # pos[1]++
npt.assert_equal(chan_wave_pos, [0, chan_j])
# Check screen edges
screen_edges = chan.screen_edges
assert bool(screen_edges & Edges.Top) == (chan.row == 0)
assert bool(screen_edges & Edges.Left) == (chan.col == 0)
assert bool(screen_edges & Edges.Bottom) == (chan.row == chan.nrow - 1)
assert bool(screen_edges & Edges.Right) == (chan.col == chan.ncol - 1)
# Check stereo edges
wave_edges = chan.wave_edges
if stereo_orientation == StereoOrientation.overlay:
assert wave_edges == ~Edges.NONE
elif stereo_orientation == StereoOrientation.v: # pos[0]++
lr = Edges.Left | Edges.Right
assert wave_edges & lr == lr
assert bool(wave_edges & Edges.Top) == (chan.row % stereo_nchan == 0)
assert bool(wave_edges & Edges.Bottom) == (
(chan.row + 1) % stereo_nchan == 0
)
else:
assert stereo_orientation == StereoOrientation.h # pos[1]++
tb = Edges.Top | Edges.Bottom
assert wave_edges & tb == tb
assert bool(wave_edges & Edges.Left) == (chan.col % stereo_nchan == 0)
assert bool(wave_edges & Edges.Right) == (
(chan.col + 1) % stereo_nchan == 0
)
def test_renderer_layout():
@ -69,21 +222,9 @@ def test_renderer_layout():
nplots = 15
r = MatplotlibRenderer(cfg, lcfg, nplots, None)
r.render_frame([RENDER_Y_ZEROS] * nplots)
layout = r.layout
# 2 columns, 8 rows
assert r.layout.ncols == 2
assert r.layout.nrows == 8
# Test EdgeFinder
def test_edge_finder():
regions2d = np.arange(24).reshape(3, 8)
edges = EdgeFinder(regions2d)
# Check borders
np.testing.assert_equal(edges.lefts, regions2d[:, 0])
np.testing.assert_equal(edges.rights, regions2d[:, -1])
np.testing.assert_equal(edges.tops, regions2d[0, :])
np.testing.assert_equal(edges.bottoms, regions2d[-1, :])
assert layout.wave_ncol == 2
assert layout.wave_nrow == 8

Wyświetl plik

@ -23,11 +23,8 @@ from corrscope.outputs import (
Stop,
)
from corrscope.renderer import RendererConfig, MatplotlibRenderer
from corrscope.settings.paths import MissingFFmpegError
from tests.test_renderer import ALL_ZEROS
from tests.test_renderer import RENDER_Y_ZEROS, WIDTH, HEIGHT
WIDTH = 192
HEIGHT = 108
if TYPE_CHECKING:
import pytest_mock
@ -90,7 +87,7 @@ def test_render_output():
renderer = MatplotlibRenderer(CFG.render, CFG.layout, nplots=1, channel_cfgs=None)
out: FFmpegOutput = NULL_FFMPEG_OUTPUT(CFG)
renderer.render_frame([ALL_ZEROS])
renderer.render_frame([RENDER_Y_ZEROS])
out.write_frame(renderer.get_frame())
assert out.close() == 0

Wyświetl plik

@ -1,38 +1,47 @@
from typing import Optional
from typing import Optional, TYPE_CHECKING
import matplotlib.colors
import numpy as np
import pytest
from matplotlib.colors import to_rgb
from corrscope.channel import ChannelConfig
from corrscope.corrscope import CorrScope, default_config, Arguments
from corrscope.layout import LayoutConfig
from corrscope.outputs import RGB_DEPTH
from corrscope.outputs import RGB_DEPTH, FFplayOutputConfig
from corrscope.renderer import RendererConfig, MatplotlibRenderer
from corrscope.wave import Flatten
WIDTH = 640
HEIGHT = 360
if TYPE_CHECKING:
import pytest_mock
ALL_ZEROS = np.array([0, 0])
WIDTH = 64
HEIGHT = 64
RENDER_Y_ZEROS = np.zeros((2, 1))
RENDER_Y_STEREO = np.zeros((2, 2))
OPACITY = 2 / 3
all_colors = pytest.mark.parametrize(
"bg_str,fg_str,grid_str",
"bg_str,fg_str,grid_str,data",
[
("#000000", "#ffffff", None),
("#ffffff", "#000000", None),
("#0000aa", "#aaaa00", None),
("#aaaa00", "#0000aa", None),
# Enabling gridlines enables Axes rectangles.
# Make sure they don't draw *over* the global figure background.
("#0000aa", "#aaaa00", "#ff00ff"), # beautiful magenta gridlines
("#aaaa00", "#0000aa", "#ff00ff"),
("#000000", "#ffffff", None, RENDER_Y_ZEROS),
("#ffffff", "#000000", None, RENDER_Y_ZEROS),
("#0000aa", "#aaaa00", None, RENDER_Y_ZEROS),
("#aaaa00", "#0000aa", None, RENDER_Y_ZEROS),
# Enabling ~~beautiful magenta~~ gridlines enables Axes rectangles.
# Make sure bg is disabled, so they don't overwrite global figure background.
("#0000aa", "#aaaa00", "#ff00ff", RENDER_Y_ZEROS),
("#aaaa00", "#0000aa", "#ff00ff", RENDER_Y_ZEROS),
("#0000aa", "#aaaa00", "#ff00ff", RENDER_Y_STEREO),
("#aaaa00", "#0000aa", "#ff00ff", RENDER_Y_STEREO),
],
)
nplots = 2
NPLOTS = 2
@all_colors
def test_default_colors(bg_str, fg_str, grid_str):
def test_default_colors(bg_str, fg_str, grid_str, data):
""" Test the default background/foreground colors. """
cfg = RendererConfig(
WIDTH,
@ -40,23 +49,24 @@ def test_default_colors(bg_str, fg_str, grid_str):
bg_color=bg_str,
init_line_color=fg_str,
grid_color=grid_str,
stereo_grid_opacity=OPACITY,
line_width=2.0,
antialiasing=False,
)
lcfg = LayoutConfig()
r = MatplotlibRenderer(cfg, lcfg, nplots, None)
verify(r, bg_str, fg_str, grid_str)
r = MatplotlibRenderer(cfg, lcfg, NPLOTS, None)
verify(r, bg_str, fg_str, grid_str, data)
# Ensure default ChannelConfig(line_color=None) does not override line color
chan = ChannelConfig(wav_path="")
channels = [chan] * nplots
r = MatplotlibRenderer(cfg, lcfg, nplots, channels)
verify(r, bg_str, fg_str, grid_str)
channels = [chan] * NPLOTS
r = MatplotlibRenderer(cfg, lcfg, NPLOTS, channels)
verify(r, bg_str, fg_str, grid_str, data)
@all_colors
def test_line_colors(bg_str, fg_str, grid_str):
def test_line_colors(bg_str, fg_str, grid_str, data):
""" Test channel-specific line color overrides """
cfg = RendererConfig(
WIDTH,
@ -64,30 +74,44 @@ def test_line_colors(bg_str, fg_str, grid_str):
bg_color=bg_str,
init_line_color="#888888",
grid_color=grid_str,
stereo_grid_opacity=OPACITY,
line_width=2.0,
antialiasing=False,
)
lcfg = LayoutConfig()
chan = ChannelConfig(wav_path="", line_color=fg_str)
channels = [chan] * nplots
r = MatplotlibRenderer(cfg, lcfg, nplots, channels)
verify(r, bg_str, fg_str, grid_str)
channels = [chan] * NPLOTS
r = MatplotlibRenderer(cfg, lcfg, NPLOTS, channels)
verify(r, bg_str, fg_str, grid_str, data)
def verify(r: MatplotlibRenderer, bg_str, fg_str, grid_str: Optional[str]):
r.render_frame([ALL_ZEROS] * nplots)
TOLERANCE = 3
def verify(
r: MatplotlibRenderer, bg_str, fg_str, grid_str: Optional[str], data: np.ndarray
):
r.render_frame([data] * NPLOTS)
frame_colors: np.ndarray = np.frombuffer(r.get_frame(), dtype=np.uint8).reshape(
(-1, RGB_DEPTH)
)
bg_u8 = [round(c * 255) for c in to_rgb(bg_str)]
fg_u8 = [round(c * 255) for c in to_rgb(fg_str)]
bg_u8 = to_rgb(bg_str)
fg_u8 = to_rgb(fg_str)
all_colors = [bg_u8, fg_u8]
if grid_str:
grid_u8 = [round(c * 255) for c in to_rgb(grid_str)]
grid_u8 = to_rgb(grid_str)
all_colors.append(grid_u8)
else:
grid_u8 = bg_u8
assert (data.shape[1] > 1) == (data is RENDER_Y_STEREO)
is_stereo = data.shape[1] > 1
if is_stereo:
stereo_grid_u8 = (grid_u8 * OPACITY + bg_u8 * (1 - OPACITY)).astype(int)
all_colors.append(stereo_grid_u8)
# Ensure background is correct
bg_frame = frame_colors[0]
@ -104,5 +128,36 @@ def verify(r: MatplotlibRenderer, bg_str, fg_str, grid_str: Optional[str]):
if grid_str:
assert np.prod(frame_colors == grid_u8, axis=-1).any(), "Missing grid_str"
# Ensure stereo grid color is present
if is_stereo:
assert (
np.min(np.sum(np.abs(frame_colors - stereo_grid_u8), axis=-1)) < TOLERANCE
), "Missing stereo gridlines"
assert (np.amax(frame_colors, axis=0) == np.amax(all_colors, axis=0)).all()
assert (np.amin(frame_colors, axis=0) == np.amin(all_colors, axis=0)).all()
def to_rgb(c) -> np.ndarray:
to_rgb = matplotlib.colors.to_rgb
return np.array([round(c * 255) for c in to_rgb(c)], dtype=int)
# Stereo *renderer* integration tests.
def test_stereo_render_integration(mocker: "pytest_mock.MockFixture"):
"""Ensure corrscope plays/renders in stereo, without crashing."""
# Stub out FFplay output.
mocker.patch.object(FFplayOutputConfig, "cls")
# Render in stereo.
cfg = default_config(
channels=[ChannelConfig("tests/stereo in-phase.wav")],
render_stereo=Flatten.Stereo,
end_time=0.5, # Reduce test duration
render=RendererConfig(WIDTH, HEIGHT),
)
# Make sure it doesn't crash.
corr = CorrScope(cfg, Arguments(".", [FFplayOutputConfig()]))
corr.play()

Wyświetl plik

@ -79,6 +79,7 @@ AllFlattens = Flatten.__members__.values()
@pytest.mark.parametrize("flatten", AllFlattens)
@pytest.mark.parametrize("return_channels", [False, True])
@pytest.mark.parametrize(
"path,nchan,peaks",
[
@ -88,19 +89,30 @@ AllFlattens = Flatten.__members__.values()
],
)
def test_stereo_flatten_modes(
flatten: Flatten, path: str, nchan: int, peaks: Sequence[float]
flatten: Flatten,
return_channels: bool,
path: str,
nchan: int,
peaks: Sequence[float],
):
"""Ensures all Flatten modes are handled properly
for stereo and mono signals."""
# return_channels=False <-> triggering.
# flatten=stereo -> rendering.
# These conditions do not currently coexist.
# if not return_channels and flatten == Flatten.Stereo:
# return
assert nchan == len(peaks)
wave = Wave(path)
if flatten not in Flatten.modes:
with pytest.raises(CorrError):
wave.with_flatten(flatten)
wave.with_flatten(flatten, return_channels)
return
else:
wave = wave.with_flatten(flatten)
wave = wave.with_flatten(flatten, return_channels)
nsamp = wave.nsamp
data = wave[:]
@ -111,7 +123,10 @@ def test_stereo_flatten_modes(
for chan_data, peak in zip(data.T, peaks):
assert_full_scale(chan_data, peak)
else:
assert data.shape == (nsamp,)
if return_channels:
assert data.shape == (nsamp, 1)
else:
assert data.shape == (nsamp,)
# If DiffAvg and in-phase, L-R=0.
if flatten == Flatten.DiffAvg: