Add backend for stereo rendering (#196)

GUI support will be added in a separate PR.
pull/357/head
nyanpasu64 2019-02-18 02:10:17 -08:00 zatwierdzone przez GitHub
rodzic d1bdf0cb34
commit 0dbba8ac09
10 zmienionych plików z 652 dodań i 195 usunięć

Wyświetl plik

@ -62,8 +62,8 @@ class Channel:
tflat = coalesce(cfg.trigger_stereo, corr_cfg.trigger_stereo) tflat = coalesce(cfg.trigger_stereo, corr_cfg.trigger_stereo)
rflat = coalesce(cfg.render_stereo, corr_cfg.render_stereo) rflat = coalesce(cfg.render_stereo, corr_cfg.render_stereo)
self.trigger_wave = wave.with_flatten(tflat) self.trigger_wave = wave.with_flatten(tflat, return_channels=False)
self.render_wave = wave.with_flatten(rflat) self.render_wave = wave.with_flatten(rflat, return_channels=True)
# `subsampling` increases `stride` and decreases `nsamp`. # `subsampling` increases `stride` and decreases `nsamp`.
# `width` increases `stride` without changing `nsamp`. # `width` increases `stride` without changing `nsamp`.

Wyświetl plik

@ -1,16 +1,41 @@
from typing import Optional, TypeVar, Callable, List, Generic, Tuple import collections
import enum
from enum import auto
from typing import Optional, TypeVar, Callable, List, Iterable
import attr
import numpy as np import numpy as np
from corrscope.config import DumpableAttrs, CorrError from corrscope.config import DumpableAttrs, CorrError, DumpEnumAsStr
from corrscope.util import ceildiv from corrscope.util import ceildiv
class LayoutConfig(DumpableAttrs, always_dump="orientation"): class Orientation(str, DumpEnumAsStr):
orientation: str = "h" h = "h"
v = "v"
class StereoOrientation(str, DumpEnumAsStr):
h = "h"
v = "v"
overlay = "overlay"
assert Orientation.h == StereoOrientation.h
H = Orientation.h
V = Orientation.v
OVERLAY = StereoOrientation.overlay
class LayoutConfig(DumpableAttrs, always_dump="orientation stereo_orientation"):
orientation: Orientation = attr.ib(default="h", converter=Orientation)
nrows: Optional[int] = None nrows: Optional[int] = None
ncols: Optional[int] = None ncols: Optional[int] = None
stereo_orientation: StereoOrientation = attr.ib(
default="h", converter=StereoOrientation
)
def __attrs_post_init__(self) -> None: def __attrs_post_init__(self) -> None:
if not self.nrows: if not self.nrows:
self.nrows = None self.nrows = None
@ -24,31 +49,90 @@ class LayoutConfig(DumpableAttrs, always_dump="orientation"):
self.ncols = 1 self.ncols = 1
class Edges(enum.Flag):
NONE = 0
Top = auto()
Left = auto()
Bottom = auto()
Right = auto()
@staticmethod
def at(nrows: int, ncols: int, row: int, col: int):
if not nrows > 0:
raise ValueError(f"invalid nrows={nrows}, must be positive")
if not ncols > 0:
raise ValueError(f"invalid ncols={ncols}, must be positive")
if not 0 <= row < nrows:
raise ValueError(f"invalid row={row} not in [0 .. nrows={nrows})")
if not 0 <= col < ncols:
raise ValueError(f"invalid col={col} not in [0 .. ncols={ncols})")
ret = Edges.NONE
if row == 0:
ret |= Edges.Top
if row + 1 == nrows:
ret |= Edges.Bottom
if col == 0:
ret |= Edges.Left
if col + 1 == ncols:
ret |= Edges.Right
return ret
def attr_idx_property(key: str, idx: int) -> property:
@property
def inner(self: "RegionSpec"):
return getattr(self, key)[idx]
return inner
@attr.dataclass
class RegionSpec:
"""
- Origin is located at top-left.
- Row 0 = top.
- Row nrows-1 = bottom.
- Col 0 = left.
- Col ncols-1 = right.
"""
size: np.ndarray
pos: np.ndarray
nrow = attr_idx_property("size", 0)
ncol = attr_idx_property("size", 1)
row = attr_idx_property("pos", 0)
col = attr_idx_property("pos", 1)
screen_edges: "Edges"
wave_edges: "Edges"
Region = TypeVar("Region") Region = TypeVar("Region")
RegionFactory = Callable[[int, int], Region] # f(row, column) -> Region RegionFactory = Callable[[RegionSpec], Region] # f(row, column) -> Region
class RendererLayout: class RendererLayout:
VALID_ORIENTATIONS = ["h", "v"] def __init__(self, cfg: LayoutConfig, wave_nchans: List[int]):
def __init__(self, cfg: LayoutConfig, nplots: int):
self.cfg = cfg self.cfg = cfg
self.nplots = nplots self.nwaves = len(wave_nchans)
self.wave_nchans = wave_nchans
# Setup layout
self.nrows, self.ncols = self._calc_layout()
self.orientation = cfg.orientation self.orientation = cfg.orientation
if self.orientation not in self.VALID_ORIENTATIONS: self.stereo_orientation = cfg.stereo_orientation
raise CorrError(
f"Invalid orientation {self.orientation} not in "
f"{self.VALID_ORIENTATIONS}"
)
def _calc_layout(self) -> Tuple[int, int]: # Setup layout
self._calc_layout()
# Shape of wave slots
wave_nrow: int
wave_ncol: int
def _calc_layout(self) -> None:
""" """
Inputs: self.cfg, self.waves Inputs: self.cfg, self.stereo_nchan
:return: (nrows, ncols) Outputs: self.wave_nrow, ncol
""" """
cfg = self.cfg cfg = self.cfg
@ -56,7 +140,7 @@ class RendererLayout:
nrows = cfg.nrows nrows = cfg.nrows
if nrows is None: if nrows is None:
raise ValueError("impossible cfg: nrows is None and true") raise ValueError("impossible cfg: nrows is None and true")
ncols = ceildiv(self.nplots, nrows) ncols = ceildiv(self.nwaves, nrows)
else: else:
if cfg.ncols is None: if cfg.ncols is None:
raise ValueError( raise ValueError(
@ -64,38 +148,110 @@ class RendererLayout:
"(__attrs_post_init__ not called?)" "(__attrs_post_init__ not called?)"
) )
ncols = cfg.ncols ncols = cfg.ncols
nrows = ceildiv(self.nplots, ncols) nrows = ceildiv(self.nwaves, ncols)
return nrows, ncols self.wave_nrow = nrows
self.wave_ncol = ncols
def arrange(self, region_factory: RegionFactory[Region]) -> List[Region]: def arrange(self, region_factory: RegionFactory[Region]) -> List[List[Region]]:
""" Generates an array of regions.
index, row, column are fed into region_factory in a row-major order [row][col].
The results are possibly reshaped into column-major order [col][row].
""" """
nspaces = self.nrows * self.ncols (row, column) are fed into region_factory in a row-major order [row][col].
inds = np.arange(nspaces) Stereo channel pairs are extracted.
rows, cols = np.unravel_index(inds, (self.nrows, self.ncols)) The results are possibly reshaped into column-major order [col][row].
row_col = list(zip(rows, cols)) :return arr[wave][channel] = Region
regions = np.empty(len(row_col), dtype=object) # type: np.ndarray[Region] """
regions[:] = [region_factory(*rc) for rc in row_col]
regions2d = regions.reshape( wave_spaces = self.wave_nrow * self.wave_ncol
(self.nrows, self.ncols) inds = np.arange(wave_spaces)
) # type: np.ndarray[Region]
# if column major: # Compute location of each wave.
if self.orientation == "v": if self.orientation == V:
regions2d = regions2d.T # column major
cols, rows = np.unravel_index(inds, (self.wave_ncol, self.wave_nrow))
else:
# row major
rows, cols = np.unravel_index(inds, (self.wave_nrow, self.wave_ncol))
return regions2d.flatten()[: self.nplots].tolist() # Generate plot for each wave.chan. Leave unused slots empty.
region_wave_chan: List[List[Region]] = []
# The order of (rows, cols) has no effect.
for stereo_nchan, wave_row, wave_col in zip(self.wave_nchans, rows, cols):
# Wave = within Screen.
# Chan = within Wave, generate a plot.
# All arrays are [y, x] == [row, col].
# Wave dim/pos (within screen)
waves_per_screen = arr(self.wave_nrow, self.wave_ncol)
wave_screen_pos = arr(wave_row, wave_col)
del wave_row, wave_col
# Channel dim/pos (within wave)
chans_per_wave = arr(1, 1) # Mutated based on orientation
chan_wave_pos = arr(0, 0) # Mutated in for-chan loop.
# Distance between chans
dchan = arr(0, 0) # Mutated based on orientation
if self.stereo_orientation == V:
chans_per_wave[0] = stereo_nchan
dchan[0] = 1
elif self.stereo_orientation == H:
chans_per_wave[1] = stereo_nchan
dchan[1] = 1
else:
assert self.stereo_orientation == OVERLAY
# Channel dim/pos (within screen)
chans_per_screen = chans_per_wave * waves_per_screen
chan_screen_pos = chans_per_wave * wave_screen_pos
# Generate plots for each channel
region_chan: List[Region] = []
region_wave_chan.append(region_chan)
region = None
for chan in range(stereo_nchan):
assert (chan_wave_pos < chans_per_wave).all()
assert (chan_screen_pos < chans_per_screen).all()
# Generate plot (channel position in screen)
if region is None or dchan.any():
screen_edges = Edges.at(*chans_per_screen, *chan_screen_pos)
wave_edges = Edges.at(*chans_per_wave, *chan_wave_pos)
# Removing .copy() causes bugs if region_factory() holds
# mutable references.
chan_spec = RegionSpec(
chans_per_screen.copy(),
chan_screen_pos.copy(),
screen_edges,
wave_edges,
)
region = region_factory(chan_spec)
region_chan.append(region)
# Move to next channel position
chan_screen_pos += dchan
chan_wave_pos += dchan
assert len(region_wave_chan) == self.nwaves
return region_wave_chan
class EdgeFinder(Generic[Region]): def arr(*args):
def __init__(self, regions2d: np.ndarray): return np.array(args)
self.tops: List[Region] = regions2d[0, :].tolist()
self.bottoms: List[Region] = regions2d[-1, :].tolist()
self.lefts: List[Region] = regions2d[:, 0].tolist() T = TypeVar("T")
self.rights: List[Region] = regions2d[:, -1].tolist()
def unique_by_id(items: Iterable[T]) -> List[T]:
seen = collections.OrderedDict()
for item in items:
if id(item) not in seen:
seen[id(item)] = item
return list(seen.values())

Wyświetl plik

@ -4,10 +4,17 @@ from typing import Optional, List, TYPE_CHECKING, Any
import attr import attr
import matplotlib import matplotlib
import matplotlib.colors
import numpy as np import numpy as np
from corrscope.config import DumpableAttrs, with_units from corrscope.config import DumpableAttrs, with_units
from corrscope.layout import RendererLayout, LayoutConfig, EdgeFinder from corrscope.layout import (
RendererLayout,
LayoutConfig,
unique_by_id,
RegionSpec,
Edges,
)
from corrscope.outputs import RGB_DEPTH, ByteBuffer from corrscope.outputs import RGB_DEPTH, ByteBuffer
from corrscope.util import coalesce from corrscope.util import coalesce
@ -58,7 +65,10 @@ class RendererConfig(DumpableAttrs, always_dump="*"):
bg_color: str = "#000000" bg_color: str = "#000000"
init_line_color: str = default_color() init_line_color: str = default_color()
grid_color: Optional[str] = None grid_color: Optional[str] = None
stereo_grid_opacity: float = 0.5
midline_color: Optional[str] = None midline_color: Optional[str] = None
v_midline: bool = False v_midline: bool = False
h_midline: bool = False h_midline: bool = False
@ -99,8 +109,8 @@ class Renderer(ABC):
channel_cfgs: Optional[List["ChannelConfig"]], channel_cfgs: Optional[List["ChannelConfig"]],
): ):
self.cfg = cfg self.cfg = cfg
self.lcfg = lcfg
self.nplots = nplots self.nplots = nplots
self.layout = RendererLayout(lcfg, nplots)
# Load line colors. # Load line colors.
if channel_cfgs is not None: if channel_cfgs is not None:
@ -165,16 +175,20 @@ class MatplotlibRenderer(Renderer):
matplotlib.rcParams, "lines.antialiased", self.cfg.antialiasing matplotlib.rcParams, "lines.antialiased", self.cfg.antialiasing
) )
# Flat array of nrows*ncols elements, ordered by cfg.rows_first.
self._fig: "Figure" self._fig: "Figure"
self._axes: List["Axes"] # set by set_layout()
self._lines: Optional[List["Line2D"]] = None # set by render_frame() first call
self._set_layout() # mutates self # _axes2d[wave][chan] = Axes
self._axes2d: List[List["Axes"]] # set by set_layout()
# _lines2d[wave][chan] = Line2D
self._lines2d: List[List[Line2D]] = []
self._lines_flat: List["Line2D"] = []
transparent = "#00000000" transparent = "#00000000"
def _set_layout(self) -> None: layout: RendererLayout
def _set_layout(self, wave_nchans: List[int]) -> None:
""" """
Creates a flat array of Matplotlib Axes, with the new layout. Creates a flat array of Matplotlib Axes, with the new layout.
Opens a window showing the Figure (and Axes). Opens a window showing the Figure (and Axes).
@ -183,6 +197,8 @@ class MatplotlibRenderer(Renderer):
Outputs: self.nrows, self.ncols, self.axes Outputs: self.nrows, self.ncols, self.axes
""" """
self.layout = RendererLayout(self.lcfg, wave_nchans)
# Create Axes # Create Axes
# https://matplotlib.org/api/_as_gen/matplotlib.pyplot.subplots.html # https://matplotlib.org/api/_as_gen/matplotlib.pyplot.subplots.html
if hasattr(self, "_fig"): if hasattr(self, "_fig"):
@ -190,24 +206,24 @@ class MatplotlibRenderer(Renderer):
# plt.close(self.fig) # plt.close(self.fig)
grid_color = self.cfg.grid_color grid_color = self.cfg.grid_color
axes2d: np.ndarray["Axes"]
self._fig = Figure() self._fig = Figure()
FigureCanvasAgg(self._fig) FigureCanvasAgg(self._fig)
axes2d = self._fig.subplots( # RegionFactory
self.layout.nrows, def axes_factory(r: RegionSpec) -> "Axes":
self.layout.ncols, width = 1 / r.ncol
squeeze=False, left = r.col / r.ncol
# Remove axis ticks (which slow down rendering) assert 0 <= left < 1
subplot_kw=dict(xticks=[], yticks=[]),
# Remove gaps between Axes TODO borders shouldn't be half-visible
gridspec_kw=dict(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0),
)
ax: "Axes" height = 1 / r.nrow
if grid_color: bottom = (r.nrow - r.row - 1) / r.nrow
# Initialize borders assert 0 <= bottom < 1
for ax in axes2d.flatten():
# Disabling xticks/yticks is unnecessary, since we hide Axises.
ax = self._fig.add_axes([left, bottom, width, height], xticks=[], yticks=[])
if grid_color:
# Initialize borders
# Hide Axises # Hide Axises
# (drawing them is very slow, and we disable ticks+labels anyway) # (drawing them is very slow, and we disable ticks+labels anyway)
ax.get_xaxis().set_visible(False) ax.get_xaxis().set_visible(False)
@ -223,25 +239,44 @@ class MatplotlibRenderer(Renderer):
for spine in ax.spines.values(): for spine in ax.spines.values():
spine.set_color(grid_color) spine.set_color(grid_color)
# gridspec_kw indexes from bottom-left corner. def hide(key: str):
# Only show bottom-left borders (x=0, y=0) ax.spines[key].set_visible(False)
ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
# Hide bottom-left edges for speed. # Hide all axes except bottom-right.
edge_axes: EdgeFinder["Axes"] = EdgeFinder(axes2d) hide("top")
for ax in edge_axes.bottoms: hide("left")
ax.spines["bottom"].set_visible(False)
for ax in edge_axes.lefts:
ax.spines["left"].set_visible(False)
else: # If bottom of screen, hide bottom. If right of screen, hide right.
# Remove Axis from Axes if r.screen_edges & Edges.Bottom:
for ax in axes2d.flatten(): hide("bottom")
if r.screen_edges & Edges.Right:
hide("right")
# Dim stereo gridlines
if self.cfg.stereo_grid_opacity > 0:
dim_color = matplotlib.colors.to_rgba_array(grid_color)[0]
dim_color[-1] = self.cfg.stereo_grid_opacity
def dim(key: str):
ax.spines[key].set_color(dim_color)
else:
dim = hide
# If not bottom of wave, dim bottom. If not right of wave, dim right.
if not r.wave_edges & Edges.Bottom:
dim("bottom")
if not r.wave_edges & Edges.Right:
dim("right")
else:
ax.set_axis_off() ax.set_axis_off()
# Generate arrangement (using nplots, cfg.orientation) return ax
self._axes: List[Axes] = self.layout.arrange(lambda row, col: axes2d[row, col])
# Generate arrangement (using self.lcfg, wave_nchans)
# _axes2d[wave][chan] = Axes
self._axes2d = self.layout.arrange(axes_factory)
# Setup figure geometry # Setup figure geometry
self._fig.set_dpi(DPI) self._fig.set_dpi(DPI)
@ -255,45 +290,66 @@ class MatplotlibRenderer(Renderer):
) )
# Initialize axes and draw waveform data # Initialize axes and draw waveform data
if self._lines is None: if not self._lines2d:
assert len(datas[0].shape) == 2, datas[0].shape
wave_nchans = [data.shape[1] for data in datas]
self._set_layout(wave_nchans)
cfg = self.cfg cfg = self.cfg
# Setup background/axes # Setup background/axes
self._fig.set_facecolor(cfg.bg_color) self._fig.set_facecolor(cfg.bg_color)
for idx, data in enumerate(datas): for idx, wave_data in enumerate(datas):
ax = self._axes[idx] wave_axes = self._axes2d[idx]
max_x = len(data) - 1 for ax in unique_by_id(wave_axes):
ax.set_xlim(0, max_x) max_x = len(wave_data) - 1
ax.set_ylim(-1, 1) ax.set_xlim(0, max_x)
ax.set_ylim(-1, 1)
# Setup midlines # Setup midlines (depends on max_x and wave_data)
midline_color = cfg.midline_color midline_color = cfg.midline_color
midline_width = pixels(1) midline_width = pixels(1)
# zorder=-100 still draws on top of gridlines :( # zorder=-100 still draws on top of gridlines :(
kw = dict(color=midline_color, linewidth=midline_width) kw = dict(color=midline_color, linewidth=midline_width)
if cfg.v_midline: if cfg.v_midline:
ax.axvline(x=max_x / 2, **kw) ax.axvline(x=max_x / 2, **kw)
if cfg.h_midline: if cfg.h_midline:
ax.axhline(y=0, **kw) ax.axhline(y=0, **kw)
self._save_background() self._save_background()
# Plot lines over background # Plot lines over background
line_width = pixels(cfg.line_width) line_width = pixels(cfg.line_width)
self._lines = []
for idx, data in enumerate(datas): # Foreach wave
ax = self._axes[idx] for wave_idx, wave_data in enumerate(datas):
line_color = self._line_params[idx].color wave_axes = self._axes2d[wave_idx]
line = ax.plot(data, color=line_color, linewidth=line_width)[0] wave_lines = []
self._lines.append(line)
# Foreach chan
for chan_idx, chan_data in enumerate(wave_data.T):
ax = wave_axes[chan_idx]
line_color = self._line_params[wave_idx].color
chan_line: Line2D = ax.plot(
chan_data, color=line_color, linewidth=line_width
)[0]
wave_lines.append(chan_line)
self._lines2d.append(wave_lines)
self._lines_flat.extend(wave_lines)
# Draw waveform data # Draw waveform data
else: else:
for idx, data in enumerate(datas): # Foreach wave
line = self._lines[idx] for wave_idx, wave_data in enumerate(datas):
line.set_ydata(data) wave_lines = self._lines2d[wave_idx]
# Foreach chan
for chan_idx, chan_data in enumerate(wave_data.T):
chan_line = wave_lines[chan_idx]
chan_line.set_ydata(chan_data)
self._redraw_over_background() self._redraw_over_background()
@ -314,8 +370,7 @@ class MatplotlibRenderer(Renderer):
canvas: FigureCanvasAgg = self._fig.canvas canvas: FigureCanvasAgg = self._fig.canvas
canvas.restore_region(self.bg_cache) canvas.restore_region(self.bg_cache)
assert self._lines is not None for line in self._lines_flat:
for line in self._lines:
line.axes.draw_artist(line) line.axes.draw_artist(line)
# https://bastibe.de/2013-05-30-speeding-up-matplotlib.html # https://bastibe.de/2013-05-30-speeding-up-matplotlib.html

Wyświetl plik

@ -40,7 +40,7 @@ class Wave:
__slots__ = """ __slots__ = """
wave_path wave_path
amplification amplification
smp_s data _flatten is_mono smp_s data return_channels _flatten is_mono
nsamp dtype nsamp dtype
center max_val center max_val
""".split() """.split()
@ -91,6 +91,7 @@ class Wave:
assert self.data.ndim in [1, 2] assert self.data.ndim in [1, 2]
self.is_mono = self.data.ndim == 1 self.is_mono = self.data.ndim == 1
self.flatten = flatten self.flatten = flatten
self.return_channels = False
# Cast self.data to stereo (nsamp, nchan) # Cast self.data to stereo (nsamp, nchan)
if self.is_mono: if self.is_mono:
@ -130,9 +131,10 @@ class Wave:
else: else:
raise CorrError(f"unexpected wavfile dtype {dtype}") raise CorrError(f"unexpected wavfile dtype {dtype}")
def with_flatten(self, flatten: Flatten) -> "Wave": def with_flatten(self, flatten: Flatten, return_channels: bool) -> "Wave":
new = copy.copy(self) new = copy.copy(self)
new.flatten = flatten new.flatten = flatten
new.return_channels = return_channels
return new return new
def __getitem__(self, index: Union[int, slice]) -> np.ndarray: def __getitem__(self, index: Union[int, slice]) -> np.ndarray:
@ -154,6 +156,9 @@ class Wave:
data -= self.center data -= self.center
data *= self.amplification / self.max_val data *= self.amplification / self.max_val
if self.return_channels and len(data.shape) == 1:
data = data.reshape(-1, 1)
return data return data
def _get(self, begin: int, end: int, subsampling: int) -> np.ndarray: def _get(self, begin: int, end: int, subsampling: int) -> np.ndarray:

Wyświetl plik

@ -1,6 +1,7 @@
""" """
Integration tests found in: Integration tests found in:
- test_cli.py - test_cli.py
- test_renderer.py
- test_output.py - test_output.py
""" """

Wyświetl plik

@ -12,6 +12,7 @@ from corrscope.channel import ChannelConfig, Channel
from corrscope.corrscope import default_config, CorrScope, BenchmarkMode, Arguments from corrscope.corrscope import default_config, CorrScope, BenchmarkMode, Arguments
from corrscope.triggers import NullTriggerConfig from corrscope.triggers import NullTriggerConfig
from corrscope.util import coalesce from corrscope.util import coalesce
from corrscope.wave import Flatten
positive = hs.integers(min_value=1, max_value=100) positive = hs.integers(min_value=1, max_value=100)
@ -134,3 +135,34 @@ def test_config_channel_width_stride(
# line_color is tested in test_renderer.py # line_color is tested in test_renderer.py
@pytest.mark.parametrize("filename", ["tests/sine440.wav", "tests/stereo in-phase.wav"])
@pytest.mark.parametrize(
("global_stereo", "chan_stereo"),
[
[Flatten.SumAvg, None],
[Flatten.Stereo, None],
[Flatten.SumAvg, Flatten.Stereo],
[Flatten.Stereo, Flatten.SumAvg],
],
)
def test_per_channel_stereo(
filename: str, global_stereo: Flatten, chan_stereo: Optional[Flatten]
):
"""Ensure you can enable/disable stereo on a per-channel basis."""
stereo = coalesce(chan_stereo, global_stereo)
# Test render wave.
cfg = default_config(render_stereo=global_stereo)
ccfg = ChannelConfig("tests/stereo in-phase.wav", render_stereo=chan_stereo)
channel = Channel(ccfg, cfg)
# Render wave *must* return stereo.
assert channel.render_wave[0:1].ndim == 2
data = channel.render_wave.get_around(0, return_nsamp=4, stride=1)
assert data.ndim == 2
if "stereo" in filename:
assert channel.render_wave._flatten == stereo
assert data.shape[1] == (2 if stereo is Flatten.Stereo else 1)

Wyświetl plik

@ -1,9 +1,22 @@
import numpy as np from typing import List
import pytest
from corrscope.layout import LayoutConfig, RendererLayout, EdgeFinder import hypothesis.strategies as hs
import numpy as np
import numpy.testing as npt
import pytest
from hypothesis import given, settings
from corrscope.layout import (
LayoutConfig,
RendererLayout,
RegionSpec,
Edges,
Orientation,
StereoOrientation,
)
from corrscope.renderer import RendererConfig, MatplotlibRenderer from corrscope.renderer import RendererConfig, MatplotlibRenderer
from tests.test_renderer import WIDTH, HEIGHT from corrscope.util import ceildiv
from tests.test_renderer import WIDTH, HEIGHT, RENDER_Y_ZEROS
def test_layout_config(): def test_layout_config():
@ -22,44 +35,184 @@ def test_layout_config():
assert default.orientation == "h" assert default.orientation == "h"
# Small range to ensure many collisions.
# max_value = 3 to allow for edge-free space in center.
integers = hs.integers(-1, 3)
@given(nrows=integers, ncols=integers, row=integers, col=integers)
@settings(max_examples=500)
def test_edges(nrows: int, ncols: int, row: int, col: int):
if not (nrows > 0 and ncols > 0 and 0 <= row < nrows and 0 <= col < ncols):
with pytest.raises(ValueError):
edges = Edges.at(nrows, ncols, row, col)
return
edges = Edges.at(nrows, ncols, row, col)
assert bool(edges & Edges.Left) == (col == 0)
assert bool(edges & Edges.Right) == (col == ncols - 1)
assert bool(edges & Edges.Top) == (row == 0)
assert bool(edges & Edges.Bottom) == (row == nrows - 1)
@pytest.mark.parametrize("lcfg", [LayoutConfig(ncols=2), LayoutConfig(nrows=8)]) @pytest.mark.parametrize("lcfg", [LayoutConfig(ncols=2), LayoutConfig(nrows=8)])
@pytest.mark.parametrize("region_type", [str, tuple, list]) def test_hlayout(lcfg):
def test_hlayout(lcfg, region_type):
nplots = 15 nplots = 15
layout = RendererLayout(lcfg, nplots) layout = RendererLayout(lcfg, [1] * nplots)
assert layout.ncols == 2 assert layout.wave_ncol == 2
assert layout.nrows == 8 assert layout.wave_nrow == 8
regions = layout.arrange(lambda row, col: region_type((row, col))) region2d: List[List[RegionSpec]] = layout.arrange(lambda arg: arg)
assert len(regions) == nplots assert len(region2d) == nplots
for i, regions in enumerate(region2d):
assert len(regions) == 1, (i, len(regions))
np.testing.assert_equal(region2d[0][0].pos, (0, 0))
np.testing.assert_equal(region2d[1][0].pos, (0, 1))
np.testing.assert_equal(region2d[2][0].pos, (1, 0))
assert regions[0] == region_type((0, 0))
assert regions[1] == region_type((0, 1))
assert regions[2] == region_type((1, 0))
m = nplots - 1 m = nplots - 1
assert regions[m] == region_type((m // 2, m % 2)) npt.assert_equal(region2d[m][0].pos, (m // 2, m % 2))
@pytest.mark.parametrize( @pytest.mark.parametrize(
"lcfg", "lcfg",
[LayoutConfig(ncols=3, orientation="v"), LayoutConfig(nrows=3, orientation="v")], [LayoutConfig(ncols=3, orientation="v"), LayoutConfig(nrows=3, orientation="v")],
) )
@pytest.mark.parametrize("region_type", [str, tuple, list]) def test_vlayout(lcfg):
def test_vlayout(lcfg, region_type):
nplots = 7 nplots = 7
layout = RendererLayout(lcfg, nplots) layout = RendererLayout(lcfg, [1] * nplots)
assert layout.ncols == 3 assert layout.wave_ncol == 3
assert layout.nrows == 3 assert layout.wave_nrow == 3
regions = layout.arrange(lambda row, col: region_type((row, col))) region2d: List[List[RegionSpec]] = layout.arrange(lambda arg: arg)
assert len(regions) == nplots assert len(region2d) == nplots
for i, regions in enumerate(region2d):
assert len(regions) == 1, (i, len(regions))
assert regions[0] == region_type((0, 0)) np.testing.assert_equal(region2d[0][0].pos, (0, 0))
assert regions[2] == region_type((2, 0)) np.testing.assert_equal(region2d[2][0].pos, (2, 0))
assert regions[3] == region_type((0, 1)) np.testing.assert_equal(region2d[3][0].pos, (0, 1))
assert regions[6] == region_type((0, 2)) np.testing.assert_equal(region2d[6][0].pos, (0, 2))
@given(
wave_nchans=hs.lists(hs.integers(1, 10), min_size=1, max_size=100),
orientation=hs.sampled_from(Orientation),
stereo_orientation=hs.sampled_from(StereoOrientation),
nrow_ncol=hs.integers(1, 100),
is_nrows=hs.booleans(),
)
def test_stereo_layout(
orientation: Orientation,
stereo_orientation: StereoOrientation,
wave_nchans: List[int],
nrow_ncol: int,
is_nrows: bool,
):
"""
Not-entirely-rigorous test for layout computation.
Mind-numbingly boring to write (and read?).
Honestly I prefer a good naming scheme in RendererLayout.arrange()
over unit tests.
- This is a regression test...
- And an obstacle to refactoring or feature development.
"""
# region Setup
if is_nrows:
nrows = nrow_ncol
ncols = None
else:
nrows = None
ncols = nrow_ncol
lcfg = LayoutConfig(
orientation=orientation,
nrows=nrows,
ncols=ncols,
stereo_orientation=stereo_orientation,
)
nwaves = len(wave_nchans)
layout = RendererLayout(lcfg, wave_nchans)
# endregion
# Assert layout dimensions correct
assert layout.wave_ncol == ncols or ceildiv(nwaves, nrows)
assert layout.wave_nrow == nrows or ceildiv(nwaves, ncols)
region2d: List[List[RegionSpec]] = layout.arrange(lambda r_spec: r_spec)
# Loop through layout regions
assert len(region2d) == len(wave_nchans)
for wave_i, wave_chans in enumerate(region2d):
stereo_nchan = wave_nchans[wave_i]
assert len(wave_chans) == stereo_nchan
# Compute channel dims within wave.
if stereo_orientation == StereoOrientation.overlay:
chans_per_wave = [1, 1]
elif stereo_orientation == StereoOrientation.v: # pos[0]++
chans_per_wave = [stereo_nchan, 1]
else:
assert stereo_orientation == StereoOrientation.h # pos[1]++
chans_per_wave = [1, stereo_nchan]
# Sanity-check position of channel 0 relative to origin (wave grid).
assert (np.add.reduce(wave_chans[0].pos) != 0) == (wave_i != 0)
npt.assert_equal(wave_chans[0].pos % chans_per_wave, 0)
for chan_j, chan in enumerate(wave_chans):
# Assert 0 <= position < size.
assert chan.pos.shape == chan.size.shape == (2,)
assert (0 <= chan.pos).all()
assert (chan.pos < chan.size).all()
# Sanity-check position of chan relative to origin (wave grid).
npt.assert_equal(
chan.pos // chans_per_wave, wave_chans[0].pos // chans_per_wave
)
# Check position of region (relative to channel 0)
chan_wave_pos = chan.pos - wave_chans[0].pos
if stereo_orientation == StereoOrientation.overlay:
npt.assert_equal(chan_wave_pos, [0, 0])
elif stereo_orientation == StereoOrientation.v: # pos[0]++
npt.assert_equal(chan_wave_pos, [chan_j, 0])
else:
assert stereo_orientation == StereoOrientation.h # pos[1]++
npt.assert_equal(chan_wave_pos, [0, chan_j])
# Check screen edges
screen_edges = chan.screen_edges
assert bool(screen_edges & Edges.Top) == (chan.row == 0)
assert bool(screen_edges & Edges.Left) == (chan.col == 0)
assert bool(screen_edges & Edges.Bottom) == (chan.row == chan.nrow - 1)
assert bool(screen_edges & Edges.Right) == (chan.col == chan.ncol - 1)
# Check stereo edges
wave_edges = chan.wave_edges
if stereo_orientation == StereoOrientation.overlay:
assert wave_edges == ~Edges.NONE
elif stereo_orientation == StereoOrientation.v: # pos[0]++
lr = Edges.Left | Edges.Right
assert wave_edges & lr == lr
assert bool(wave_edges & Edges.Top) == (chan.row % stereo_nchan == 0)
assert bool(wave_edges & Edges.Bottom) == (
(chan.row + 1) % stereo_nchan == 0
)
else:
assert stereo_orientation == StereoOrientation.h # pos[1]++
tb = Edges.Top | Edges.Bottom
assert wave_edges & tb == tb
assert bool(wave_edges & Edges.Left) == (chan.col % stereo_nchan == 0)
assert bool(wave_edges & Edges.Right) == (
(chan.col + 1) % stereo_nchan == 0
)
def test_renderer_layout(): def test_renderer_layout():
@ -69,21 +222,9 @@ def test_renderer_layout():
nplots = 15 nplots = 15
r = MatplotlibRenderer(cfg, lcfg, nplots, None) r = MatplotlibRenderer(cfg, lcfg, nplots, None)
r.render_frame([RENDER_Y_ZEROS] * nplots)
layout = r.layout
# 2 columns, 8 rows # 2 columns, 8 rows
assert r.layout.ncols == 2 assert layout.wave_ncol == 2
assert r.layout.nrows == 8 assert layout.wave_nrow == 8
# Test EdgeFinder
def test_edge_finder():
regions2d = np.arange(24).reshape(3, 8)
edges = EdgeFinder(regions2d)
# Check borders
np.testing.assert_equal(edges.lefts, regions2d[:, 0])
np.testing.assert_equal(edges.rights, regions2d[:, -1])
np.testing.assert_equal(edges.tops, regions2d[0, :])
np.testing.assert_equal(edges.bottoms, regions2d[-1, :])

Wyświetl plik

@ -23,11 +23,8 @@ from corrscope.outputs import (
Stop, Stop,
) )
from corrscope.renderer import RendererConfig, MatplotlibRenderer from corrscope.renderer import RendererConfig, MatplotlibRenderer
from corrscope.settings.paths import MissingFFmpegError from tests.test_renderer import RENDER_Y_ZEROS, WIDTH, HEIGHT
from tests.test_renderer import ALL_ZEROS
WIDTH = 192
HEIGHT = 108
if TYPE_CHECKING: if TYPE_CHECKING:
import pytest_mock import pytest_mock
@ -90,7 +87,7 @@ def test_render_output():
renderer = MatplotlibRenderer(CFG.render, CFG.layout, nplots=1, channel_cfgs=None) renderer = MatplotlibRenderer(CFG.render, CFG.layout, nplots=1, channel_cfgs=None)
out: FFmpegOutput = NULL_FFMPEG_OUTPUT(CFG) out: FFmpegOutput = NULL_FFMPEG_OUTPUT(CFG)
renderer.render_frame([ALL_ZEROS]) renderer.render_frame([RENDER_Y_ZEROS])
out.write_frame(renderer.get_frame()) out.write_frame(renderer.get_frame())
assert out.close() == 0 assert out.close() == 0

Wyświetl plik

@ -1,38 +1,47 @@
from typing import Optional from typing import Optional, TYPE_CHECKING
import matplotlib.colors
import numpy as np import numpy as np
import pytest import pytest
from matplotlib.colors import to_rgb
from corrscope.channel import ChannelConfig from corrscope.channel import ChannelConfig
from corrscope.corrscope import CorrScope, default_config, Arguments
from corrscope.layout import LayoutConfig from corrscope.layout import LayoutConfig
from corrscope.outputs import RGB_DEPTH from corrscope.outputs import RGB_DEPTH, FFplayOutputConfig
from corrscope.renderer import RendererConfig, MatplotlibRenderer from corrscope.renderer import RendererConfig, MatplotlibRenderer
from corrscope.wave import Flatten
WIDTH = 640 if TYPE_CHECKING:
HEIGHT = 360 import pytest_mock
ALL_ZEROS = np.array([0, 0]) WIDTH = 64
HEIGHT = 64
RENDER_Y_ZEROS = np.zeros((2, 1))
RENDER_Y_STEREO = np.zeros((2, 2))
OPACITY = 2 / 3
all_colors = pytest.mark.parametrize( all_colors = pytest.mark.parametrize(
"bg_str,fg_str,grid_str", "bg_str,fg_str,grid_str,data",
[ [
("#000000", "#ffffff", None), ("#000000", "#ffffff", None, RENDER_Y_ZEROS),
("#ffffff", "#000000", None), ("#ffffff", "#000000", None, RENDER_Y_ZEROS),
("#0000aa", "#aaaa00", None), ("#0000aa", "#aaaa00", None, RENDER_Y_ZEROS),
("#aaaa00", "#0000aa", None), ("#aaaa00", "#0000aa", None, RENDER_Y_ZEROS),
# Enabling gridlines enables Axes rectangles. # Enabling ~~beautiful magenta~~ gridlines enables Axes rectangles.
# Make sure they don't draw *over* the global figure background. # Make sure bg is disabled, so they don't overwrite global figure background.
("#0000aa", "#aaaa00", "#ff00ff"), # beautiful magenta gridlines ("#0000aa", "#aaaa00", "#ff00ff", RENDER_Y_ZEROS),
("#aaaa00", "#0000aa", "#ff00ff"), ("#aaaa00", "#0000aa", "#ff00ff", RENDER_Y_ZEROS),
("#0000aa", "#aaaa00", "#ff00ff", RENDER_Y_STEREO),
("#aaaa00", "#0000aa", "#ff00ff", RENDER_Y_STEREO),
], ],
) )
nplots = 2 NPLOTS = 2
@all_colors @all_colors
def test_default_colors(bg_str, fg_str, grid_str): def test_default_colors(bg_str, fg_str, grid_str, data):
""" Test the default background/foreground colors. """ """ Test the default background/foreground colors. """
cfg = RendererConfig( cfg = RendererConfig(
WIDTH, WIDTH,
@ -40,23 +49,24 @@ def test_default_colors(bg_str, fg_str, grid_str):
bg_color=bg_str, bg_color=bg_str,
init_line_color=fg_str, init_line_color=fg_str,
grid_color=grid_str, grid_color=grid_str,
stereo_grid_opacity=OPACITY,
line_width=2.0, line_width=2.0,
antialiasing=False, antialiasing=False,
) )
lcfg = LayoutConfig() lcfg = LayoutConfig()
r = MatplotlibRenderer(cfg, lcfg, nplots, None) r = MatplotlibRenderer(cfg, lcfg, NPLOTS, None)
verify(r, bg_str, fg_str, grid_str) verify(r, bg_str, fg_str, grid_str, data)
# Ensure default ChannelConfig(line_color=None) does not override line color # Ensure default ChannelConfig(line_color=None) does not override line color
chan = ChannelConfig(wav_path="") chan = ChannelConfig(wav_path="")
channels = [chan] * nplots channels = [chan] * NPLOTS
r = MatplotlibRenderer(cfg, lcfg, nplots, channels) r = MatplotlibRenderer(cfg, lcfg, NPLOTS, channels)
verify(r, bg_str, fg_str, grid_str) verify(r, bg_str, fg_str, grid_str, data)
@all_colors @all_colors
def test_line_colors(bg_str, fg_str, grid_str): def test_line_colors(bg_str, fg_str, grid_str, data):
""" Test channel-specific line color overrides """ """ Test channel-specific line color overrides """
cfg = RendererConfig( cfg = RendererConfig(
WIDTH, WIDTH,
@ -64,30 +74,44 @@ def test_line_colors(bg_str, fg_str, grid_str):
bg_color=bg_str, bg_color=bg_str,
init_line_color="#888888", init_line_color="#888888",
grid_color=grid_str, grid_color=grid_str,
stereo_grid_opacity=OPACITY,
line_width=2.0, line_width=2.0,
antialiasing=False, antialiasing=False,
) )
lcfg = LayoutConfig() lcfg = LayoutConfig()
chan = ChannelConfig(wav_path="", line_color=fg_str) chan = ChannelConfig(wav_path="", line_color=fg_str)
channels = [chan] * nplots channels = [chan] * NPLOTS
r = MatplotlibRenderer(cfg, lcfg, nplots, channels) r = MatplotlibRenderer(cfg, lcfg, NPLOTS, channels)
verify(r, bg_str, fg_str, grid_str) verify(r, bg_str, fg_str, grid_str, data)
def verify(r: MatplotlibRenderer, bg_str, fg_str, grid_str: Optional[str]): TOLERANCE = 3
r.render_frame([ALL_ZEROS] * nplots)
def verify(
r: MatplotlibRenderer, bg_str, fg_str, grid_str: Optional[str], data: np.ndarray
):
r.render_frame([data] * NPLOTS)
frame_colors: np.ndarray = np.frombuffer(r.get_frame(), dtype=np.uint8).reshape( frame_colors: np.ndarray = np.frombuffer(r.get_frame(), dtype=np.uint8).reshape(
(-1, RGB_DEPTH) (-1, RGB_DEPTH)
) )
bg_u8 = [round(c * 255) for c in to_rgb(bg_str)] bg_u8 = to_rgb(bg_str)
fg_u8 = [round(c * 255) for c in to_rgb(fg_str)] fg_u8 = to_rgb(fg_str)
all_colors = [bg_u8, fg_u8] all_colors = [bg_u8, fg_u8]
if grid_str: if grid_str:
grid_u8 = [round(c * 255) for c in to_rgb(grid_str)] grid_u8 = to_rgb(grid_str)
all_colors.append(grid_u8) all_colors.append(grid_u8)
else:
grid_u8 = bg_u8
assert (data.shape[1] > 1) == (data is RENDER_Y_STEREO)
is_stereo = data.shape[1] > 1
if is_stereo:
stereo_grid_u8 = (grid_u8 * OPACITY + bg_u8 * (1 - OPACITY)).astype(int)
all_colors.append(stereo_grid_u8)
# Ensure background is correct # Ensure background is correct
bg_frame = frame_colors[0] bg_frame = frame_colors[0]
@ -104,5 +128,36 @@ def verify(r: MatplotlibRenderer, bg_str, fg_str, grid_str: Optional[str]):
if grid_str: if grid_str:
assert np.prod(frame_colors == grid_u8, axis=-1).any(), "Missing grid_str" assert np.prod(frame_colors == grid_u8, axis=-1).any(), "Missing grid_str"
# Ensure stereo grid color is present
if is_stereo:
assert (
np.min(np.sum(np.abs(frame_colors - stereo_grid_u8), axis=-1)) < TOLERANCE
), "Missing stereo gridlines"
assert (np.amax(frame_colors, axis=0) == np.amax(all_colors, axis=0)).all() assert (np.amax(frame_colors, axis=0) == np.amax(all_colors, axis=0)).all()
assert (np.amin(frame_colors, axis=0) == np.amin(all_colors, axis=0)).all() assert (np.amin(frame_colors, axis=0) == np.amin(all_colors, axis=0)).all()
def to_rgb(c) -> np.ndarray:
to_rgb = matplotlib.colors.to_rgb
return np.array([round(c * 255) for c in to_rgb(c)], dtype=int)
# Stereo *renderer* integration tests.
def test_stereo_render_integration(mocker: "pytest_mock.MockFixture"):
"""Ensure corrscope plays/renders in stereo, without crashing."""
# Stub out FFplay output.
mocker.patch.object(FFplayOutputConfig, "cls")
# Render in stereo.
cfg = default_config(
channels=[ChannelConfig("tests/stereo in-phase.wav")],
render_stereo=Flatten.Stereo,
end_time=0.5, # Reduce test duration
render=RendererConfig(WIDTH, HEIGHT),
)
# Make sure it doesn't crash.
corr = CorrScope(cfg, Arguments(".", [FFplayOutputConfig()]))
corr.play()

Wyświetl plik

@ -79,6 +79,7 @@ AllFlattens = Flatten.__members__.values()
@pytest.mark.parametrize("flatten", AllFlattens) @pytest.mark.parametrize("flatten", AllFlattens)
@pytest.mark.parametrize("return_channels", [False, True])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"path,nchan,peaks", "path,nchan,peaks",
[ [
@ -88,19 +89,30 @@ AllFlattens = Flatten.__members__.values()
], ],
) )
def test_stereo_flatten_modes( def test_stereo_flatten_modes(
flatten: Flatten, path: str, nchan: int, peaks: Sequence[float] flatten: Flatten,
return_channels: bool,
path: str,
nchan: int,
peaks: Sequence[float],
): ):
"""Ensures all Flatten modes are handled properly """Ensures all Flatten modes are handled properly
for stereo and mono signals.""" for stereo and mono signals."""
# return_channels=False <-> triggering.
# flatten=stereo -> rendering.
# These conditions do not currently coexist.
# if not return_channels and flatten == Flatten.Stereo:
# return
assert nchan == len(peaks) assert nchan == len(peaks)
wave = Wave(path) wave = Wave(path)
if flatten not in Flatten.modes: if flatten not in Flatten.modes:
with pytest.raises(CorrError): with pytest.raises(CorrError):
wave.with_flatten(flatten) wave.with_flatten(flatten, return_channels)
return return
else: else:
wave = wave.with_flatten(flatten) wave = wave.with_flatten(flatten, return_channels)
nsamp = wave.nsamp nsamp = wave.nsamp
data = wave[:] data = wave[:]
@ -111,7 +123,10 @@ def test_stereo_flatten_modes(
for chan_data, peak in zip(data.T, peaks): for chan_data, peak in zip(data.T, peaks):
assert_full_scale(chan_data, peak) assert_full_scale(chan_data, peak)
else: else:
assert data.shape == (nsamp,) if return_channels:
assert data.shape == (nsamp, 1)
else:
assert data.shape == (nsamp,)
# If DiffAvg and in-phase, L-R=0. # If DiffAvg and in-phase, L-R=0.
if flatten == Flatten.DiffAvg: if flatten == Flatten.DiffAvg: