kopia lustrzana https://github.com/corrscope/corrscope
rodzic
d1bdf0cb34
commit
0dbba8ac09
|
@ -62,8 +62,8 @@ class Channel:
|
|||
tflat = coalesce(cfg.trigger_stereo, corr_cfg.trigger_stereo)
|
||||
rflat = coalesce(cfg.render_stereo, corr_cfg.render_stereo)
|
||||
|
||||
self.trigger_wave = wave.with_flatten(tflat)
|
||||
self.render_wave = wave.with_flatten(rflat)
|
||||
self.trigger_wave = wave.with_flatten(tflat, return_channels=False)
|
||||
self.render_wave = wave.with_flatten(rflat, return_channels=True)
|
||||
|
||||
# `subsampling` increases `stride` and decreases `nsamp`.
|
||||
# `width` increases `stride` without changing `nsamp`.
|
||||
|
|
|
@ -1,16 +1,41 @@
|
|||
from typing import Optional, TypeVar, Callable, List, Generic, Tuple
|
||||
import collections
|
||||
import enum
|
||||
from enum import auto
|
||||
from typing import Optional, TypeVar, Callable, List, Iterable
|
||||
|
||||
import attr
|
||||
import numpy as np
|
||||
|
||||
from corrscope.config import DumpableAttrs, CorrError
|
||||
from corrscope.config import DumpableAttrs, CorrError, DumpEnumAsStr
|
||||
from corrscope.util import ceildiv
|
||||
|
||||
|
||||
class LayoutConfig(DumpableAttrs, always_dump="orientation"):
|
||||
orientation: str = "h"
|
||||
class Orientation(str, DumpEnumAsStr):
|
||||
h = "h"
|
||||
v = "v"
|
||||
|
||||
|
||||
class StereoOrientation(str, DumpEnumAsStr):
|
||||
h = "h"
|
||||
v = "v"
|
||||
overlay = "overlay"
|
||||
|
||||
|
||||
assert Orientation.h == StereoOrientation.h
|
||||
H = Orientation.h
|
||||
V = Orientation.v
|
||||
OVERLAY = StereoOrientation.overlay
|
||||
|
||||
|
||||
class LayoutConfig(DumpableAttrs, always_dump="orientation stereo_orientation"):
|
||||
orientation: Orientation = attr.ib(default="h", converter=Orientation)
|
||||
nrows: Optional[int] = None
|
||||
ncols: Optional[int] = None
|
||||
|
||||
stereo_orientation: StereoOrientation = attr.ib(
|
||||
default="h", converter=StereoOrientation
|
||||
)
|
||||
|
||||
def __attrs_post_init__(self) -> None:
|
||||
if not self.nrows:
|
||||
self.nrows = None
|
||||
|
@ -24,31 +49,90 @@ class LayoutConfig(DumpableAttrs, always_dump="orientation"):
|
|||
self.ncols = 1
|
||||
|
||||
|
||||
class Edges(enum.Flag):
|
||||
NONE = 0
|
||||
Top = auto()
|
||||
Left = auto()
|
||||
Bottom = auto()
|
||||
Right = auto()
|
||||
|
||||
@staticmethod
|
||||
def at(nrows: int, ncols: int, row: int, col: int):
|
||||
if not nrows > 0:
|
||||
raise ValueError(f"invalid nrows={nrows}, must be positive")
|
||||
if not ncols > 0:
|
||||
raise ValueError(f"invalid ncols={ncols}, must be positive")
|
||||
if not 0 <= row < nrows:
|
||||
raise ValueError(f"invalid row={row} not in [0 .. nrows={nrows})")
|
||||
if not 0 <= col < ncols:
|
||||
raise ValueError(f"invalid col={col} not in [0 .. ncols={ncols})")
|
||||
|
||||
ret = Edges.NONE
|
||||
if row == 0:
|
||||
ret |= Edges.Top
|
||||
if row + 1 == nrows:
|
||||
ret |= Edges.Bottom
|
||||
if col == 0:
|
||||
ret |= Edges.Left
|
||||
if col + 1 == ncols:
|
||||
ret |= Edges.Right
|
||||
return ret
|
||||
|
||||
|
||||
def attr_idx_property(key: str, idx: int) -> property:
|
||||
@property
|
||||
def inner(self: "RegionSpec"):
|
||||
return getattr(self, key)[idx]
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
@attr.dataclass
|
||||
class RegionSpec:
|
||||
"""
|
||||
- Origin is located at top-left.
|
||||
- Row 0 = top.
|
||||
- Row nrows-1 = bottom.
|
||||
- Col 0 = left.
|
||||
- Col ncols-1 = right.
|
||||
"""
|
||||
|
||||
size: np.ndarray
|
||||
pos: np.ndarray
|
||||
|
||||
nrow = attr_idx_property("size", 0)
|
||||
ncol = attr_idx_property("size", 1)
|
||||
row = attr_idx_property("pos", 0)
|
||||
col = attr_idx_property("pos", 1)
|
||||
|
||||
screen_edges: "Edges"
|
||||
wave_edges: "Edges"
|
||||
|
||||
|
||||
Region = TypeVar("Region")
|
||||
RegionFactory = Callable[[int, int], Region] # f(row, column) -> Region
|
||||
RegionFactory = Callable[[RegionSpec], Region] # f(row, column) -> Region
|
||||
|
||||
|
||||
class RendererLayout:
|
||||
VALID_ORIENTATIONS = ["h", "v"]
|
||||
|
||||
def __init__(self, cfg: LayoutConfig, nplots: int):
|
||||
def __init__(self, cfg: LayoutConfig, wave_nchans: List[int]):
|
||||
self.cfg = cfg
|
||||
self.nplots = nplots
|
||||
|
||||
# Setup layout
|
||||
self.nrows, self.ncols = self._calc_layout()
|
||||
self.nwaves = len(wave_nchans)
|
||||
self.wave_nchans = wave_nchans
|
||||
|
||||
self.orientation = cfg.orientation
|
||||
if self.orientation not in self.VALID_ORIENTATIONS:
|
||||
raise CorrError(
|
||||
f"Invalid orientation {self.orientation} not in "
|
||||
f"{self.VALID_ORIENTATIONS}"
|
||||
)
|
||||
self.stereo_orientation = cfg.stereo_orientation
|
||||
|
||||
def _calc_layout(self) -> Tuple[int, int]:
|
||||
# Setup layout
|
||||
self._calc_layout()
|
||||
|
||||
# Shape of wave slots
|
||||
wave_nrow: int
|
||||
wave_ncol: int
|
||||
|
||||
def _calc_layout(self) -> None:
|
||||
"""
|
||||
Inputs: self.cfg, self.waves
|
||||
:return: (nrows, ncols)
|
||||
Inputs: self.cfg, self.stereo_nchan
|
||||
Outputs: self.wave_nrow, ncol
|
||||
"""
|
||||
cfg = self.cfg
|
||||
|
||||
|
@ -56,7 +140,7 @@ class RendererLayout:
|
|||
nrows = cfg.nrows
|
||||
if nrows is None:
|
||||
raise ValueError("impossible cfg: nrows is None and true")
|
||||
ncols = ceildiv(self.nplots, nrows)
|
||||
ncols = ceildiv(self.nwaves, nrows)
|
||||
else:
|
||||
if cfg.ncols is None:
|
||||
raise ValueError(
|
||||
|
@ -64,38 +148,110 @@ class RendererLayout:
|
|||
"(__attrs_post_init__ not called?)"
|
||||
)
|
||||
ncols = cfg.ncols
|
||||
nrows = ceildiv(self.nplots, ncols)
|
||||
nrows = ceildiv(self.nwaves, ncols)
|
||||
|
||||
return nrows, ncols
|
||||
self.wave_nrow = nrows
|
||||
self.wave_ncol = ncols
|
||||
|
||||
def arrange(self, region_factory: RegionFactory[Region]) -> List[Region]:
|
||||
""" Generates an array of regions.
|
||||
|
||||
index, row, column are fed into region_factory in a row-major order [row][col].
|
||||
The results are possibly reshaped into column-major order [col][row].
|
||||
def arrange(self, region_factory: RegionFactory[Region]) -> List[List[Region]]:
|
||||
"""
|
||||
nspaces = self.nrows * self.ncols
|
||||
inds = np.arange(nspaces)
|
||||
rows, cols = np.unravel_index(inds, (self.nrows, self.ncols))
|
||||
(row, column) are fed into region_factory in a row-major order [row][col].
|
||||
Stereo channel pairs are extracted.
|
||||
The results are possibly reshaped into column-major order [col][row].
|
||||
|
||||
row_col = list(zip(rows, cols))
|
||||
regions = np.empty(len(row_col), dtype=object) # type: np.ndarray[Region]
|
||||
regions[:] = [region_factory(*rc) for rc in row_col]
|
||||
:return arr[wave][channel] = Region
|
||||
"""
|
||||
|
||||
regions2d = regions.reshape(
|
||||
(self.nrows, self.ncols)
|
||||
) # type: np.ndarray[Region]
|
||||
wave_spaces = self.wave_nrow * self.wave_ncol
|
||||
inds = np.arange(wave_spaces)
|
||||
|
||||
# if column major:
|
||||
if self.orientation == "v":
|
||||
regions2d = regions2d.T
|
||||
# Compute location of each wave.
|
||||
if self.orientation == V:
|
||||
# column major
|
||||
cols, rows = np.unravel_index(inds, (self.wave_ncol, self.wave_nrow))
|
||||
else:
|
||||
# row major
|
||||
rows, cols = np.unravel_index(inds, (self.wave_nrow, self.wave_ncol))
|
||||
|
||||
return regions2d.flatten()[: self.nplots].tolist()
|
||||
# Generate plot for each wave.chan. Leave unused slots empty.
|
||||
region_wave_chan: List[List[Region]] = []
|
||||
|
||||
# The order of (rows, cols) has no effect.
|
||||
for stereo_nchan, wave_row, wave_col in zip(self.wave_nchans, rows, cols):
|
||||
# Wave = within Screen.
|
||||
# Chan = within Wave, generate a plot.
|
||||
# All arrays are [y, x] == [row, col].
|
||||
|
||||
# Wave dim/pos (within screen)
|
||||
waves_per_screen = arr(self.wave_nrow, self.wave_ncol)
|
||||
wave_screen_pos = arr(wave_row, wave_col)
|
||||
del wave_row, wave_col
|
||||
|
||||
# Channel dim/pos (within wave)
|
||||
chans_per_wave = arr(1, 1) # Mutated based on orientation
|
||||
chan_wave_pos = arr(0, 0) # Mutated in for-chan loop.
|
||||
|
||||
# Distance between chans
|
||||
dchan = arr(0, 0) # Mutated based on orientation
|
||||
|
||||
if self.stereo_orientation == V:
|
||||
chans_per_wave[0] = stereo_nchan
|
||||
dchan[0] = 1
|
||||
elif self.stereo_orientation == H:
|
||||
chans_per_wave[1] = stereo_nchan
|
||||
dchan[1] = 1
|
||||
else:
|
||||
assert self.stereo_orientation == OVERLAY
|
||||
|
||||
# Channel dim/pos (within screen)
|
||||
chans_per_screen = chans_per_wave * waves_per_screen
|
||||
chan_screen_pos = chans_per_wave * wave_screen_pos
|
||||
|
||||
# Generate plots for each channel
|
||||
region_chan: List[Region] = []
|
||||
region_wave_chan.append(region_chan)
|
||||
region = None
|
||||
|
||||
for chan in range(stereo_nchan):
|
||||
assert (chan_wave_pos < chans_per_wave).all()
|
||||
assert (chan_screen_pos < chans_per_screen).all()
|
||||
|
||||
# Generate plot (channel position in screen)
|
||||
if region is None or dchan.any():
|
||||
screen_edges = Edges.at(*chans_per_screen, *chan_screen_pos)
|
||||
wave_edges = Edges.at(*chans_per_wave, *chan_wave_pos)
|
||||
|
||||
# Removing .copy() causes bugs if region_factory() holds
|
||||
# mutable references.
|
||||
chan_spec = RegionSpec(
|
||||
chans_per_screen.copy(),
|
||||
chan_screen_pos.copy(),
|
||||
screen_edges,
|
||||
wave_edges,
|
||||
)
|
||||
region = region_factory(chan_spec)
|
||||
region_chan.append(region)
|
||||
|
||||
# Move to next channel position
|
||||
chan_screen_pos += dchan
|
||||
chan_wave_pos += dchan
|
||||
|
||||
assert len(region_wave_chan) == self.nwaves
|
||||
return region_wave_chan
|
||||
|
||||
|
||||
class EdgeFinder(Generic[Region]):
|
||||
def __init__(self, regions2d: np.ndarray):
|
||||
self.tops: List[Region] = regions2d[0, :].tolist()
|
||||
self.bottoms: List[Region] = regions2d[-1, :].tolist()
|
||||
self.lefts: List[Region] = regions2d[:, 0].tolist()
|
||||
self.rights: List[Region] = regions2d[:, -1].tolist()
|
||||
def arr(*args):
|
||||
return np.array(args)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def unique_by_id(items: Iterable[T]) -> List[T]:
|
||||
seen = collections.OrderedDict()
|
||||
|
||||
for item in items:
|
||||
if id(item) not in seen:
|
||||
seen[id(item)] = item
|
||||
|
||||
return list(seen.values())
|
||||
|
|
|
@ -4,10 +4,17 @@ from typing import Optional, List, TYPE_CHECKING, Any
|
|||
|
||||
import attr
|
||||
import matplotlib
|
||||
import matplotlib.colors
|
||||
import numpy as np
|
||||
|
||||
from corrscope.config import DumpableAttrs, with_units
|
||||
from corrscope.layout import RendererLayout, LayoutConfig, EdgeFinder
|
||||
from corrscope.layout import (
|
||||
RendererLayout,
|
||||
LayoutConfig,
|
||||
unique_by_id,
|
||||
RegionSpec,
|
||||
Edges,
|
||||
)
|
||||
from corrscope.outputs import RGB_DEPTH, ByteBuffer
|
||||
from corrscope.util import coalesce
|
||||
|
||||
|
@ -58,7 +65,10 @@ class RendererConfig(DumpableAttrs, always_dump="*"):
|
|||
|
||||
bg_color: str = "#000000"
|
||||
init_line_color: str = default_color()
|
||||
|
||||
grid_color: Optional[str] = None
|
||||
stereo_grid_opacity: float = 0.5
|
||||
|
||||
midline_color: Optional[str] = None
|
||||
v_midline: bool = False
|
||||
h_midline: bool = False
|
||||
|
@ -99,8 +109,8 @@ class Renderer(ABC):
|
|||
channel_cfgs: Optional[List["ChannelConfig"]],
|
||||
):
|
||||
self.cfg = cfg
|
||||
self.lcfg = lcfg
|
||||
self.nplots = nplots
|
||||
self.layout = RendererLayout(lcfg, nplots)
|
||||
|
||||
# Load line colors.
|
||||
if channel_cfgs is not None:
|
||||
|
@ -165,16 +175,20 @@ class MatplotlibRenderer(Renderer):
|
|||
matplotlib.rcParams, "lines.antialiased", self.cfg.antialiasing
|
||||
)
|
||||
|
||||
# Flat array of nrows*ncols elements, ordered by cfg.rows_first.
|
||||
self._fig: "Figure"
|
||||
self._axes: List["Axes"] # set by set_layout()
|
||||
self._lines: Optional[List["Line2D"]] = None # set by render_frame() first call
|
||||
|
||||
self._set_layout() # mutates self
|
||||
# _axes2d[wave][chan] = Axes
|
||||
self._axes2d: List[List["Axes"]] # set by set_layout()
|
||||
|
||||
# _lines2d[wave][chan] = Line2D
|
||||
self._lines2d: List[List[Line2D]] = []
|
||||
self._lines_flat: List["Line2D"] = []
|
||||
|
||||
transparent = "#00000000"
|
||||
|
||||
def _set_layout(self) -> None:
|
||||
layout: RendererLayout
|
||||
|
||||
def _set_layout(self, wave_nchans: List[int]) -> None:
|
||||
"""
|
||||
Creates a flat array of Matplotlib Axes, with the new layout.
|
||||
Opens a window showing the Figure (and Axes).
|
||||
|
@ -183,6 +197,8 @@ class MatplotlibRenderer(Renderer):
|
|||
Outputs: self.nrows, self.ncols, self.axes
|
||||
"""
|
||||
|
||||
self.layout = RendererLayout(self.lcfg, wave_nchans)
|
||||
|
||||
# Create Axes
|
||||
# https://matplotlib.org/api/_as_gen/matplotlib.pyplot.subplots.html
|
||||
if hasattr(self, "_fig"):
|
||||
|
@ -190,24 +206,24 @@ class MatplotlibRenderer(Renderer):
|
|||
# plt.close(self.fig)
|
||||
|
||||
grid_color = self.cfg.grid_color
|
||||
axes2d: np.ndarray["Axes"]
|
||||
self._fig = Figure()
|
||||
FigureCanvasAgg(self._fig)
|
||||
|
||||
axes2d = self._fig.subplots(
|
||||
self.layout.nrows,
|
||||
self.layout.ncols,
|
||||
squeeze=False,
|
||||
# Remove axis ticks (which slow down rendering)
|
||||
subplot_kw=dict(xticks=[], yticks=[]),
|
||||
# Remove gaps between Axes TODO borders shouldn't be half-visible
|
||||
gridspec_kw=dict(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0),
|
||||
)
|
||||
# RegionFactory
|
||||
def axes_factory(r: RegionSpec) -> "Axes":
|
||||
width = 1 / r.ncol
|
||||
left = r.col / r.ncol
|
||||
assert 0 <= left < 1
|
||||
|
||||
ax: "Axes"
|
||||
if grid_color:
|
||||
# Initialize borders
|
||||
for ax in axes2d.flatten():
|
||||
height = 1 / r.nrow
|
||||
bottom = (r.nrow - r.row - 1) / r.nrow
|
||||
assert 0 <= bottom < 1
|
||||
|
||||
# Disabling xticks/yticks is unnecessary, since we hide Axises.
|
||||
ax = self._fig.add_axes([left, bottom, width, height], xticks=[], yticks=[])
|
||||
|
||||
if grid_color:
|
||||
# Initialize borders
|
||||
# Hide Axises
|
||||
# (drawing them is very slow, and we disable ticks+labels anyway)
|
||||
ax.get_xaxis().set_visible(False)
|
||||
|
@ -223,25 +239,44 @@ class MatplotlibRenderer(Renderer):
|
|||
for spine in ax.spines.values():
|
||||
spine.set_color(grid_color)
|
||||
|
||||
# gridspec_kw indexes from bottom-left corner.
|
||||
# Only show bottom-left borders (x=0, y=0)
|
||||
ax.spines["top"].set_visible(False)
|
||||
ax.spines["right"].set_visible(False)
|
||||
def hide(key: str):
|
||||
ax.spines[key].set_visible(False)
|
||||
|
||||
# Hide bottom-left edges for speed.
|
||||
edge_axes: EdgeFinder["Axes"] = EdgeFinder(axes2d)
|
||||
for ax in edge_axes.bottoms:
|
||||
ax.spines["bottom"].set_visible(False)
|
||||
for ax in edge_axes.lefts:
|
||||
ax.spines["left"].set_visible(False)
|
||||
# Hide all axes except bottom-right.
|
||||
hide("top")
|
||||
hide("left")
|
||||
|
||||
else:
|
||||
# Remove Axis from Axes
|
||||
for ax in axes2d.flatten():
|
||||
# If bottom of screen, hide bottom. If right of screen, hide right.
|
||||
if r.screen_edges & Edges.Bottom:
|
||||
hide("bottom")
|
||||
if r.screen_edges & Edges.Right:
|
||||
hide("right")
|
||||
|
||||
# Dim stereo gridlines
|
||||
if self.cfg.stereo_grid_opacity > 0:
|
||||
dim_color = matplotlib.colors.to_rgba_array(grid_color)[0]
|
||||
dim_color[-1] = self.cfg.stereo_grid_opacity
|
||||
|
||||
def dim(key: str):
|
||||
ax.spines[key].set_color(dim_color)
|
||||
|
||||
else:
|
||||
dim = hide
|
||||
|
||||
# If not bottom of wave, dim bottom. If not right of wave, dim right.
|
||||
if not r.wave_edges & Edges.Bottom:
|
||||
dim("bottom")
|
||||
if not r.wave_edges & Edges.Right:
|
||||
dim("right")
|
||||
|
||||
else:
|
||||
ax.set_axis_off()
|
||||
|
||||
# Generate arrangement (using nplots, cfg.orientation)
|
||||
self._axes: List[Axes] = self.layout.arrange(lambda row, col: axes2d[row, col])
|
||||
return ax
|
||||
|
||||
# Generate arrangement (using self.lcfg, wave_nchans)
|
||||
# _axes2d[wave][chan] = Axes
|
||||
self._axes2d = self.layout.arrange(axes_factory)
|
||||
|
||||
# Setup figure geometry
|
||||
self._fig.set_dpi(DPI)
|
||||
|
@ -255,45 +290,66 @@ class MatplotlibRenderer(Renderer):
|
|||
)
|
||||
|
||||
# Initialize axes and draw waveform data
|
||||
if self._lines is None:
|
||||
if not self._lines2d:
|
||||
assert len(datas[0].shape) == 2, datas[0].shape
|
||||
|
||||
wave_nchans = [data.shape[1] for data in datas]
|
||||
self._set_layout(wave_nchans)
|
||||
|
||||
cfg = self.cfg
|
||||
|
||||
# Setup background/axes
|
||||
self._fig.set_facecolor(cfg.bg_color)
|
||||
for idx, data in enumerate(datas):
|
||||
ax = self._axes[idx]
|
||||
max_x = len(data) - 1
|
||||
ax.set_xlim(0, max_x)
|
||||
ax.set_ylim(-1, 1)
|
||||
for idx, wave_data in enumerate(datas):
|
||||
wave_axes = self._axes2d[idx]
|
||||
for ax in unique_by_id(wave_axes):
|
||||
max_x = len(wave_data) - 1
|
||||
ax.set_xlim(0, max_x)
|
||||
ax.set_ylim(-1, 1)
|
||||
|
||||
# Setup midlines
|
||||
midline_color = cfg.midline_color
|
||||
midline_width = pixels(1)
|
||||
# Setup midlines (depends on max_x and wave_data)
|
||||
midline_color = cfg.midline_color
|
||||
midline_width = pixels(1)
|
||||
|
||||
# zorder=-100 still draws on top of gridlines :(
|
||||
kw = dict(color=midline_color, linewidth=midline_width)
|
||||
if cfg.v_midline:
|
||||
ax.axvline(x=max_x / 2, **kw)
|
||||
if cfg.h_midline:
|
||||
ax.axhline(y=0, **kw)
|
||||
# zorder=-100 still draws on top of gridlines :(
|
||||
kw = dict(color=midline_color, linewidth=midline_width)
|
||||
if cfg.v_midline:
|
||||
ax.axvline(x=max_x / 2, **kw)
|
||||
if cfg.h_midline:
|
||||
ax.axhline(y=0, **kw)
|
||||
|
||||
self._save_background()
|
||||
|
||||
# Plot lines over background
|
||||
line_width = pixels(cfg.line_width)
|
||||
self._lines = []
|
||||
|
||||
for idx, data in enumerate(datas):
|
||||
ax = self._axes[idx]
|
||||
line_color = self._line_params[idx].color
|
||||
line = ax.plot(data, color=line_color, linewidth=line_width)[0]
|
||||
self._lines.append(line)
|
||||
# Foreach wave
|
||||
for wave_idx, wave_data in enumerate(datas):
|
||||
wave_axes = self._axes2d[wave_idx]
|
||||
wave_lines = []
|
||||
|
||||
# Foreach chan
|
||||
for chan_idx, chan_data in enumerate(wave_data.T):
|
||||
ax = wave_axes[chan_idx]
|
||||
line_color = self._line_params[wave_idx].color
|
||||
chan_line: Line2D = ax.plot(
|
||||
chan_data, color=line_color, linewidth=line_width
|
||||
)[0]
|
||||
wave_lines.append(chan_line)
|
||||
|
||||
self._lines2d.append(wave_lines)
|
||||
self._lines_flat.extend(wave_lines)
|
||||
|
||||
# Draw waveform data
|
||||
else:
|
||||
for idx, data in enumerate(datas):
|
||||
line = self._lines[idx]
|
||||
line.set_ydata(data)
|
||||
# Foreach wave
|
||||
for wave_idx, wave_data in enumerate(datas):
|
||||
wave_lines = self._lines2d[wave_idx]
|
||||
|
||||
# Foreach chan
|
||||
for chan_idx, chan_data in enumerate(wave_data.T):
|
||||
chan_line = wave_lines[chan_idx]
|
||||
chan_line.set_ydata(chan_data)
|
||||
|
||||
self._redraw_over_background()
|
||||
|
||||
|
@ -314,8 +370,7 @@ class MatplotlibRenderer(Renderer):
|
|||
canvas: FigureCanvasAgg = self._fig.canvas
|
||||
canvas.restore_region(self.bg_cache)
|
||||
|
||||
assert self._lines is not None
|
||||
for line in self._lines:
|
||||
for line in self._lines_flat:
|
||||
line.axes.draw_artist(line)
|
||||
|
||||
# https://bastibe.de/2013-05-30-speeding-up-matplotlib.html
|
||||
|
|
|
@ -40,7 +40,7 @@ class Wave:
|
|||
__slots__ = """
|
||||
wave_path
|
||||
amplification
|
||||
smp_s data _flatten is_mono
|
||||
smp_s data return_channels _flatten is_mono
|
||||
nsamp dtype
|
||||
center max_val
|
||||
""".split()
|
||||
|
@ -91,6 +91,7 @@ class Wave:
|
|||
assert self.data.ndim in [1, 2]
|
||||
self.is_mono = self.data.ndim == 1
|
||||
self.flatten = flatten
|
||||
self.return_channels = False
|
||||
|
||||
# Cast self.data to stereo (nsamp, nchan)
|
||||
if self.is_mono:
|
||||
|
@ -130,9 +131,10 @@ class Wave:
|
|||
else:
|
||||
raise CorrError(f"unexpected wavfile dtype {dtype}")
|
||||
|
||||
def with_flatten(self, flatten: Flatten) -> "Wave":
|
||||
def with_flatten(self, flatten: Flatten, return_channels: bool) -> "Wave":
|
||||
new = copy.copy(self)
|
||||
new.flatten = flatten
|
||||
new.return_channels = return_channels
|
||||
return new
|
||||
|
||||
def __getitem__(self, index: Union[int, slice]) -> np.ndarray:
|
||||
|
@ -154,6 +156,9 @@ class Wave:
|
|||
|
||||
data -= self.center
|
||||
data *= self.amplification / self.max_val
|
||||
|
||||
if self.return_channels and len(data.shape) == 1:
|
||||
data = data.reshape(-1, 1)
|
||||
return data
|
||||
|
||||
def _get(self, begin: int, end: int, subsampling: int) -> np.ndarray:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""
|
||||
Integration tests found in:
|
||||
- test_cli.py
|
||||
- test_renderer.py
|
||||
- test_output.py
|
||||
"""
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ from corrscope.channel import ChannelConfig, Channel
|
|||
from corrscope.corrscope import default_config, CorrScope, BenchmarkMode, Arguments
|
||||
from corrscope.triggers import NullTriggerConfig
|
||||
from corrscope.util import coalesce
|
||||
from corrscope.wave import Flatten
|
||||
|
||||
|
||||
positive = hs.integers(min_value=1, max_value=100)
|
||||
|
@ -134,3 +135,34 @@ def test_config_channel_width_stride(
|
|||
|
||||
|
||||
# line_color is tested in test_renderer.py
|
||||
|
||||
|
||||
@pytest.mark.parametrize("filename", ["tests/sine440.wav", "tests/stereo in-phase.wav"])
|
||||
@pytest.mark.parametrize(
|
||||
("global_stereo", "chan_stereo"),
|
||||
[
|
||||
[Flatten.SumAvg, None],
|
||||
[Flatten.Stereo, None],
|
||||
[Flatten.SumAvg, Flatten.Stereo],
|
||||
[Flatten.Stereo, Flatten.SumAvg],
|
||||
],
|
||||
)
|
||||
def test_per_channel_stereo(
|
||||
filename: str, global_stereo: Flatten, chan_stereo: Optional[Flatten]
|
||||
):
|
||||
"""Ensure you can enable/disable stereo on a per-channel basis."""
|
||||
stereo = coalesce(chan_stereo, global_stereo)
|
||||
|
||||
# Test render wave.
|
||||
cfg = default_config(render_stereo=global_stereo)
|
||||
ccfg = ChannelConfig("tests/stereo in-phase.wav", render_stereo=chan_stereo)
|
||||
channel = Channel(ccfg, cfg)
|
||||
|
||||
# Render wave *must* return stereo.
|
||||
assert channel.render_wave[0:1].ndim == 2
|
||||
data = channel.render_wave.get_around(0, return_nsamp=4, stride=1)
|
||||
assert data.ndim == 2
|
||||
|
||||
if "stereo" in filename:
|
||||
assert channel.render_wave._flatten == stereo
|
||||
assert data.shape[1] == (2 if stereo is Flatten.Stereo else 1)
|
||||
|
|
|
@ -1,9 +1,22 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
from typing import List
|
||||
|
||||
from corrscope.layout import LayoutConfig, RendererLayout, EdgeFinder
|
||||
import hypothesis.strategies as hs
|
||||
import numpy as np
|
||||
import numpy.testing as npt
|
||||
import pytest
|
||||
from hypothesis import given, settings
|
||||
|
||||
from corrscope.layout import (
|
||||
LayoutConfig,
|
||||
RendererLayout,
|
||||
RegionSpec,
|
||||
Edges,
|
||||
Orientation,
|
||||
StereoOrientation,
|
||||
)
|
||||
from corrscope.renderer import RendererConfig, MatplotlibRenderer
|
||||
from tests.test_renderer import WIDTH, HEIGHT
|
||||
from corrscope.util import ceildiv
|
||||
from tests.test_renderer import WIDTH, HEIGHT, RENDER_Y_ZEROS
|
||||
|
||||
|
||||
def test_layout_config():
|
||||
|
@ -22,44 +35,184 @@ def test_layout_config():
|
|||
assert default.orientation == "h"
|
||||
|
||||
|
||||
# Small range to ensure many collisions.
|
||||
# max_value = 3 to allow for edge-free space in center.
|
||||
integers = hs.integers(-1, 3)
|
||||
|
||||
|
||||
@given(nrows=integers, ncols=integers, row=integers, col=integers)
|
||||
@settings(max_examples=500)
|
||||
def test_edges(nrows: int, ncols: int, row: int, col: int):
|
||||
if not (nrows > 0 and ncols > 0 and 0 <= row < nrows and 0 <= col < ncols):
|
||||
with pytest.raises(ValueError):
|
||||
edges = Edges.at(nrows, ncols, row, col)
|
||||
return
|
||||
|
||||
edges = Edges.at(nrows, ncols, row, col)
|
||||
assert bool(edges & Edges.Left) == (col == 0)
|
||||
assert bool(edges & Edges.Right) == (col == ncols - 1)
|
||||
assert bool(edges & Edges.Top) == (row == 0)
|
||||
assert bool(edges & Edges.Bottom) == (row == nrows - 1)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("lcfg", [LayoutConfig(ncols=2), LayoutConfig(nrows=8)])
|
||||
@pytest.mark.parametrize("region_type", [str, tuple, list])
|
||||
def test_hlayout(lcfg, region_type):
|
||||
def test_hlayout(lcfg):
|
||||
nplots = 15
|
||||
layout = RendererLayout(lcfg, nplots)
|
||||
layout = RendererLayout(lcfg, [1] * nplots)
|
||||
|
||||
assert layout.ncols == 2
|
||||
assert layout.nrows == 8
|
||||
assert layout.wave_ncol == 2
|
||||
assert layout.wave_nrow == 8
|
||||
|
||||
regions = layout.arrange(lambda row, col: region_type((row, col)))
|
||||
assert len(regions) == nplots
|
||||
region2d: List[List[RegionSpec]] = layout.arrange(lambda arg: arg)
|
||||
assert len(region2d) == nplots
|
||||
for i, regions in enumerate(region2d):
|
||||
assert len(regions) == 1, (i, len(regions))
|
||||
|
||||
np.testing.assert_equal(region2d[0][0].pos, (0, 0))
|
||||
np.testing.assert_equal(region2d[1][0].pos, (0, 1))
|
||||
np.testing.assert_equal(region2d[2][0].pos, (1, 0))
|
||||
|
||||
assert regions[0] == region_type((0, 0))
|
||||
assert regions[1] == region_type((0, 1))
|
||||
assert regions[2] == region_type((1, 0))
|
||||
m = nplots - 1
|
||||
assert regions[m] == region_type((m // 2, m % 2))
|
||||
npt.assert_equal(region2d[m][0].pos, (m // 2, m % 2))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"lcfg",
|
||||
[LayoutConfig(ncols=3, orientation="v"), LayoutConfig(nrows=3, orientation="v")],
|
||||
)
|
||||
@pytest.mark.parametrize("region_type", [str, tuple, list])
|
||||
def test_vlayout(lcfg, region_type):
|
||||
def test_vlayout(lcfg):
|
||||
nplots = 7
|
||||
layout = RendererLayout(lcfg, nplots)
|
||||
layout = RendererLayout(lcfg, [1] * nplots)
|
||||
|
||||
assert layout.ncols == 3
|
||||
assert layout.nrows == 3
|
||||
assert layout.wave_ncol == 3
|
||||
assert layout.wave_nrow == 3
|
||||
|
||||
regions = layout.arrange(lambda row, col: region_type((row, col)))
|
||||
assert len(regions) == nplots
|
||||
region2d: List[List[RegionSpec]] = layout.arrange(lambda arg: arg)
|
||||
assert len(region2d) == nplots
|
||||
for i, regions in enumerate(region2d):
|
||||
assert len(regions) == 1, (i, len(regions))
|
||||
|
||||
assert regions[0] == region_type((0, 0))
|
||||
assert regions[2] == region_type((2, 0))
|
||||
assert regions[3] == region_type((0, 1))
|
||||
assert regions[6] == region_type((0, 2))
|
||||
np.testing.assert_equal(region2d[0][0].pos, (0, 0))
|
||||
np.testing.assert_equal(region2d[2][0].pos, (2, 0))
|
||||
np.testing.assert_equal(region2d[3][0].pos, (0, 1))
|
||||
np.testing.assert_equal(region2d[6][0].pos, (0, 2))
|
||||
|
||||
|
||||
@given(
|
||||
wave_nchans=hs.lists(hs.integers(1, 10), min_size=1, max_size=100),
|
||||
orientation=hs.sampled_from(Orientation),
|
||||
stereo_orientation=hs.sampled_from(StereoOrientation),
|
||||
nrow_ncol=hs.integers(1, 100),
|
||||
is_nrows=hs.booleans(),
|
||||
)
|
||||
def test_stereo_layout(
|
||||
orientation: Orientation,
|
||||
stereo_orientation: StereoOrientation,
|
||||
wave_nchans: List[int],
|
||||
nrow_ncol: int,
|
||||
is_nrows: bool,
|
||||
):
|
||||
"""
|
||||
Not-entirely-rigorous test for layout computation.
|
||||
Mind-numbingly boring to write (and read?).
|
||||
|
||||
Honestly I prefer a good naming scheme in RendererLayout.arrange()
|
||||
over unit tests.
|
||||
|
||||
- This is a regression test...
|
||||
- And an obstacle to refactoring or feature development.
|
||||
"""
|
||||
# region Setup
|
||||
if is_nrows:
|
||||
nrows = nrow_ncol
|
||||
ncols = None
|
||||
else:
|
||||
nrows = None
|
||||
ncols = nrow_ncol
|
||||
|
||||
lcfg = LayoutConfig(
|
||||
orientation=orientation,
|
||||
nrows=nrows,
|
||||
ncols=ncols,
|
||||
stereo_orientation=stereo_orientation,
|
||||
)
|
||||
nwaves = len(wave_nchans)
|
||||
layout = RendererLayout(lcfg, wave_nchans)
|
||||
# endregion
|
||||
|
||||
# Assert layout dimensions correct
|
||||
assert layout.wave_ncol == ncols or ceildiv(nwaves, nrows)
|
||||
assert layout.wave_nrow == nrows or ceildiv(nwaves, ncols)
|
||||
|
||||
region2d: List[List[RegionSpec]] = layout.arrange(lambda r_spec: r_spec)
|
||||
|
||||
# Loop through layout regions
|
||||
assert len(region2d) == len(wave_nchans)
|
||||
for wave_i, wave_chans in enumerate(region2d):
|
||||
stereo_nchan = wave_nchans[wave_i]
|
||||
assert len(wave_chans) == stereo_nchan
|
||||
|
||||
# Compute channel dims within wave.
|
||||
if stereo_orientation == StereoOrientation.overlay:
|
||||
chans_per_wave = [1, 1]
|
||||
elif stereo_orientation == StereoOrientation.v: # pos[0]++
|
||||
chans_per_wave = [stereo_nchan, 1]
|
||||
else:
|
||||
assert stereo_orientation == StereoOrientation.h # pos[1]++
|
||||
chans_per_wave = [1, stereo_nchan]
|
||||
|
||||
# Sanity-check position of channel 0 relative to origin (wave grid).
|
||||
assert (np.add.reduce(wave_chans[0].pos) != 0) == (wave_i != 0)
|
||||
npt.assert_equal(wave_chans[0].pos % chans_per_wave, 0)
|
||||
|
||||
for chan_j, chan in enumerate(wave_chans):
|
||||
# Assert 0 <= position < size.
|
||||
assert chan.pos.shape == chan.size.shape == (2,)
|
||||
assert (0 <= chan.pos).all()
|
||||
assert (chan.pos < chan.size).all()
|
||||
|
||||
# Sanity-check position of chan relative to origin (wave grid).
|
||||
npt.assert_equal(
|
||||
chan.pos // chans_per_wave, wave_chans[0].pos // chans_per_wave
|
||||
)
|
||||
|
||||
# Check position of region (relative to channel 0)
|
||||
chan_wave_pos = chan.pos - wave_chans[0].pos
|
||||
|
||||
if stereo_orientation == StereoOrientation.overlay:
|
||||
npt.assert_equal(chan_wave_pos, [0, 0])
|
||||
elif stereo_orientation == StereoOrientation.v: # pos[0]++
|
||||
npt.assert_equal(chan_wave_pos, [chan_j, 0])
|
||||
else:
|
||||
assert stereo_orientation == StereoOrientation.h # pos[1]++
|
||||
npt.assert_equal(chan_wave_pos, [0, chan_j])
|
||||
|
||||
# Check screen edges
|
||||
screen_edges = chan.screen_edges
|
||||
assert bool(screen_edges & Edges.Top) == (chan.row == 0)
|
||||
assert bool(screen_edges & Edges.Left) == (chan.col == 0)
|
||||
assert bool(screen_edges & Edges.Bottom) == (chan.row == chan.nrow - 1)
|
||||
assert bool(screen_edges & Edges.Right) == (chan.col == chan.ncol - 1)
|
||||
|
||||
# Check stereo edges
|
||||
wave_edges = chan.wave_edges
|
||||
if stereo_orientation == StereoOrientation.overlay:
|
||||
assert wave_edges == ~Edges.NONE
|
||||
elif stereo_orientation == StereoOrientation.v: # pos[0]++
|
||||
lr = Edges.Left | Edges.Right
|
||||
assert wave_edges & lr == lr
|
||||
assert bool(wave_edges & Edges.Top) == (chan.row % stereo_nchan == 0)
|
||||
assert bool(wave_edges & Edges.Bottom) == (
|
||||
(chan.row + 1) % stereo_nchan == 0
|
||||
)
|
||||
else:
|
||||
assert stereo_orientation == StereoOrientation.h # pos[1]++
|
||||
tb = Edges.Top | Edges.Bottom
|
||||
assert wave_edges & tb == tb
|
||||
assert bool(wave_edges & Edges.Left) == (chan.col % stereo_nchan == 0)
|
||||
assert bool(wave_edges & Edges.Right) == (
|
||||
(chan.col + 1) % stereo_nchan == 0
|
||||
)
|
||||
|
||||
|
||||
def test_renderer_layout():
|
||||
|
@ -69,21 +222,9 @@ def test_renderer_layout():
|
|||
nplots = 15
|
||||
|
||||
r = MatplotlibRenderer(cfg, lcfg, nplots, None)
|
||||
r.render_frame([RENDER_Y_ZEROS] * nplots)
|
||||
layout = r.layout
|
||||
|
||||
# 2 columns, 8 rows
|
||||
assert r.layout.ncols == 2
|
||||
assert r.layout.nrows == 8
|
||||
|
||||
|
||||
# Test EdgeFinder
|
||||
|
||||
|
||||
def test_edge_finder():
|
||||
regions2d = np.arange(24).reshape(3, 8)
|
||||
edges = EdgeFinder(regions2d)
|
||||
|
||||
# Check borders
|
||||
np.testing.assert_equal(edges.lefts, regions2d[:, 0])
|
||||
np.testing.assert_equal(edges.rights, regions2d[:, -1])
|
||||
np.testing.assert_equal(edges.tops, regions2d[0, :])
|
||||
np.testing.assert_equal(edges.bottoms, regions2d[-1, :])
|
||||
assert layout.wave_ncol == 2
|
||||
assert layout.wave_nrow == 8
|
||||
|
|
|
@ -23,11 +23,8 @@ from corrscope.outputs import (
|
|||
Stop,
|
||||
)
|
||||
from corrscope.renderer import RendererConfig, MatplotlibRenderer
|
||||
from corrscope.settings.paths import MissingFFmpegError
|
||||
from tests.test_renderer import ALL_ZEROS
|
||||
from tests.test_renderer import RENDER_Y_ZEROS, WIDTH, HEIGHT
|
||||
|
||||
WIDTH = 192
|
||||
HEIGHT = 108
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import pytest_mock
|
||||
|
@ -90,7 +87,7 @@ def test_render_output():
|
|||
renderer = MatplotlibRenderer(CFG.render, CFG.layout, nplots=1, channel_cfgs=None)
|
||||
out: FFmpegOutput = NULL_FFMPEG_OUTPUT(CFG)
|
||||
|
||||
renderer.render_frame([ALL_ZEROS])
|
||||
renderer.render_frame([RENDER_Y_ZEROS])
|
||||
out.write_frame(renderer.get_frame())
|
||||
|
||||
assert out.close() == 0
|
||||
|
|
|
@ -1,38 +1,47 @@
|
|||
from typing import Optional
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
|
||||
import matplotlib.colors
|
||||
import numpy as np
|
||||
import pytest
|
||||
from matplotlib.colors import to_rgb
|
||||
|
||||
from corrscope.channel import ChannelConfig
|
||||
from corrscope.corrscope import CorrScope, default_config, Arguments
|
||||
from corrscope.layout import LayoutConfig
|
||||
from corrscope.outputs import RGB_DEPTH
|
||||
from corrscope.outputs import RGB_DEPTH, FFplayOutputConfig
|
||||
from corrscope.renderer import RendererConfig, MatplotlibRenderer
|
||||
from corrscope.wave import Flatten
|
||||
|
||||
WIDTH = 640
|
||||
HEIGHT = 360
|
||||
if TYPE_CHECKING:
|
||||
import pytest_mock
|
||||
|
||||
ALL_ZEROS = np.array([0, 0])
|
||||
WIDTH = 64
|
||||
HEIGHT = 64
|
||||
|
||||
RENDER_Y_ZEROS = np.zeros((2, 1))
|
||||
RENDER_Y_STEREO = np.zeros((2, 2))
|
||||
OPACITY = 2 / 3
|
||||
|
||||
all_colors = pytest.mark.parametrize(
|
||||
"bg_str,fg_str,grid_str",
|
||||
"bg_str,fg_str,grid_str,data",
|
||||
[
|
||||
("#000000", "#ffffff", None),
|
||||
("#ffffff", "#000000", None),
|
||||
("#0000aa", "#aaaa00", None),
|
||||
("#aaaa00", "#0000aa", None),
|
||||
# Enabling gridlines enables Axes rectangles.
|
||||
# Make sure they don't draw *over* the global figure background.
|
||||
("#0000aa", "#aaaa00", "#ff00ff"), # beautiful magenta gridlines
|
||||
("#aaaa00", "#0000aa", "#ff00ff"),
|
||||
("#000000", "#ffffff", None, RENDER_Y_ZEROS),
|
||||
("#ffffff", "#000000", None, RENDER_Y_ZEROS),
|
||||
("#0000aa", "#aaaa00", None, RENDER_Y_ZEROS),
|
||||
("#aaaa00", "#0000aa", None, RENDER_Y_ZEROS),
|
||||
# Enabling ~~beautiful magenta~~ gridlines enables Axes rectangles.
|
||||
# Make sure bg is disabled, so they don't overwrite global figure background.
|
||||
("#0000aa", "#aaaa00", "#ff00ff", RENDER_Y_ZEROS),
|
||||
("#aaaa00", "#0000aa", "#ff00ff", RENDER_Y_ZEROS),
|
||||
("#0000aa", "#aaaa00", "#ff00ff", RENDER_Y_STEREO),
|
||||
("#aaaa00", "#0000aa", "#ff00ff", RENDER_Y_STEREO),
|
||||
],
|
||||
)
|
||||
|
||||
nplots = 2
|
||||
NPLOTS = 2
|
||||
|
||||
|
||||
@all_colors
|
||||
def test_default_colors(bg_str, fg_str, grid_str):
|
||||
def test_default_colors(bg_str, fg_str, grid_str, data):
|
||||
""" Test the default background/foreground colors. """
|
||||
cfg = RendererConfig(
|
||||
WIDTH,
|
||||
|
@ -40,23 +49,24 @@ def test_default_colors(bg_str, fg_str, grid_str):
|
|||
bg_color=bg_str,
|
||||
init_line_color=fg_str,
|
||||
grid_color=grid_str,
|
||||
stereo_grid_opacity=OPACITY,
|
||||
line_width=2.0,
|
||||
antialiasing=False,
|
||||
)
|
||||
lcfg = LayoutConfig()
|
||||
|
||||
r = MatplotlibRenderer(cfg, lcfg, nplots, None)
|
||||
verify(r, bg_str, fg_str, grid_str)
|
||||
r = MatplotlibRenderer(cfg, lcfg, NPLOTS, None)
|
||||
verify(r, bg_str, fg_str, grid_str, data)
|
||||
|
||||
# Ensure default ChannelConfig(line_color=None) does not override line color
|
||||
chan = ChannelConfig(wav_path="")
|
||||
channels = [chan] * nplots
|
||||
r = MatplotlibRenderer(cfg, lcfg, nplots, channels)
|
||||
verify(r, bg_str, fg_str, grid_str)
|
||||
channels = [chan] * NPLOTS
|
||||
r = MatplotlibRenderer(cfg, lcfg, NPLOTS, channels)
|
||||
verify(r, bg_str, fg_str, grid_str, data)
|
||||
|
||||
|
||||
@all_colors
|
||||
def test_line_colors(bg_str, fg_str, grid_str):
|
||||
def test_line_colors(bg_str, fg_str, grid_str, data):
|
||||
""" Test channel-specific line color overrides """
|
||||
cfg = RendererConfig(
|
||||
WIDTH,
|
||||
|
@ -64,30 +74,44 @@ def test_line_colors(bg_str, fg_str, grid_str):
|
|||
bg_color=bg_str,
|
||||
init_line_color="#888888",
|
||||
grid_color=grid_str,
|
||||
stereo_grid_opacity=OPACITY,
|
||||
line_width=2.0,
|
||||
antialiasing=False,
|
||||
)
|
||||
lcfg = LayoutConfig()
|
||||
|
||||
chan = ChannelConfig(wav_path="", line_color=fg_str)
|
||||
channels = [chan] * nplots
|
||||
r = MatplotlibRenderer(cfg, lcfg, nplots, channels)
|
||||
verify(r, bg_str, fg_str, grid_str)
|
||||
channels = [chan] * NPLOTS
|
||||
r = MatplotlibRenderer(cfg, lcfg, NPLOTS, channels)
|
||||
verify(r, bg_str, fg_str, grid_str, data)
|
||||
|
||||
|
||||
def verify(r: MatplotlibRenderer, bg_str, fg_str, grid_str: Optional[str]):
|
||||
r.render_frame([ALL_ZEROS] * nplots)
|
||||
TOLERANCE = 3
|
||||
|
||||
|
||||
def verify(
|
||||
r: MatplotlibRenderer, bg_str, fg_str, grid_str: Optional[str], data: np.ndarray
|
||||
):
|
||||
r.render_frame([data] * NPLOTS)
|
||||
frame_colors: np.ndarray = np.frombuffer(r.get_frame(), dtype=np.uint8).reshape(
|
||||
(-1, RGB_DEPTH)
|
||||
)
|
||||
|
||||
bg_u8 = [round(c * 255) for c in to_rgb(bg_str)]
|
||||
fg_u8 = [round(c * 255) for c in to_rgb(fg_str)]
|
||||
bg_u8 = to_rgb(bg_str)
|
||||
fg_u8 = to_rgb(fg_str)
|
||||
all_colors = [bg_u8, fg_u8]
|
||||
|
||||
if grid_str:
|
||||
grid_u8 = [round(c * 255) for c in to_rgb(grid_str)]
|
||||
grid_u8 = to_rgb(grid_str)
|
||||
all_colors.append(grid_u8)
|
||||
else:
|
||||
grid_u8 = bg_u8
|
||||
|
||||
assert (data.shape[1] > 1) == (data is RENDER_Y_STEREO)
|
||||
is_stereo = data.shape[1] > 1
|
||||
if is_stereo:
|
||||
stereo_grid_u8 = (grid_u8 * OPACITY + bg_u8 * (1 - OPACITY)).astype(int)
|
||||
all_colors.append(stereo_grid_u8)
|
||||
|
||||
# Ensure background is correct
|
||||
bg_frame = frame_colors[0]
|
||||
|
@ -104,5 +128,36 @@ def verify(r: MatplotlibRenderer, bg_str, fg_str, grid_str: Optional[str]):
|
|||
if grid_str:
|
||||
assert np.prod(frame_colors == grid_u8, axis=-1).any(), "Missing grid_str"
|
||||
|
||||
# Ensure stereo grid color is present
|
||||
if is_stereo:
|
||||
assert (
|
||||
np.min(np.sum(np.abs(frame_colors - stereo_grid_u8), axis=-1)) < TOLERANCE
|
||||
), "Missing stereo gridlines"
|
||||
|
||||
assert (np.amax(frame_colors, axis=0) == np.amax(all_colors, axis=0)).all()
|
||||
assert (np.amin(frame_colors, axis=0) == np.amin(all_colors, axis=0)).all()
|
||||
|
||||
|
||||
def to_rgb(c) -> np.ndarray:
|
||||
to_rgb = matplotlib.colors.to_rgb
|
||||
return np.array([round(c * 255) for c in to_rgb(c)], dtype=int)
|
||||
|
||||
|
||||
# Stereo *renderer* integration tests.
|
||||
def test_stereo_render_integration(mocker: "pytest_mock.MockFixture"):
|
||||
"""Ensure corrscope plays/renders in stereo, without crashing."""
|
||||
|
||||
# Stub out FFplay output.
|
||||
mocker.patch.object(FFplayOutputConfig, "cls")
|
||||
|
||||
# Render in stereo.
|
||||
cfg = default_config(
|
||||
channels=[ChannelConfig("tests/stereo in-phase.wav")],
|
||||
render_stereo=Flatten.Stereo,
|
||||
end_time=0.5, # Reduce test duration
|
||||
render=RendererConfig(WIDTH, HEIGHT),
|
||||
)
|
||||
|
||||
# Make sure it doesn't crash.
|
||||
corr = CorrScope(cfg, Arguments(".", [FFplayOutputConfig()]))
|
||||
corr.play()
|
||||
|
|
|
@ -79,6 +79,7 @@ AllFlattens = Flatten.__members__.values()
|
|||
|
||||
|
||||
@pytest.mark.parametrize("flatten", AllFlattens)
|
||||
@pytest.mark.parametrize("return_channels", [False, True])
|
||||
@pytest.mark.parametrize(
|
||||
"path,nchan,peaks",
|
||||
[
|
||||
|
@ -88,19 +89,30 @@ AllFlattens = Flatten.__members__.values()
|
|||
],
|
||||
)
|
||||
def test_stereo_flatten_modes(
|
||||
flatten: Flatten, path: str, nchan: int, peaks: Sequence[float]
|
||||
flatten: Flatten,
|
||||
return_channels: bool,
|
||||
path: str,
|
||||
nchan: int,
|
||||
peaks: Sequence[float],
|
||||
):
|
||||
"""Ensures all Flatten modes are handled properly
|
||||
for stereo and mono signals."""
|
||||
|
||||
# return_channels=False <-> triggering.
|
||||
# flatten=stereo -> rendering.
|
||||
# These conditions do not currently coexist.
|
||||
# if not return_channels and flatten == Flatten.Stereo:
|
||||
# return
|
||||
|
||||
assert nchan == len(peaks)
|
||||
wave = Wave(path)
|
||||
|
||||
if flatten not in Flatten.modes:
|
||||
with pytest.raises(CorrError):
|
||||
wave.with_flatten(flatten)
|
||||
wave.with_flatten(flatten, return_channels)
|
||||
return
|
||||
else:
|
||||
wave = wave.with_flatten(flatten)
|
||||
wave = wave.with_flatten(flatten, return_channels)
|
||||
|
||||
nsamp = wave.nsamp
|
||||
data = wave[:]
|
||||
|
@ -111,7 +123,10 @@ def test_stereo_flatten_modes(
|
|||
for chan_data, peak in zip(data.T, peaks):
|
||||
assert_full_scale(chan_data, peak)
|
||||
else:
|
||||
assert data.shape == (nsamp,)
|
||||
if return_channels:
|
||||
assert data.shape == (nsamp, 1)
|
||||
else:
|
||||
assert data.shape == (nsamp,)
|
||||
|
||||
# If DiffAvg and in-phase, L-R=0.
|
||||
if flatten == Flatten.DiffAvg:
|
||||
|
|
Ładowanie…
Reference in New Issue