kopia lustrzana https://github.com/corrscope/corrscope
rodzic
d1bdf0cb34
commit
0dbba8ac09
|
@ -62,8 +62,8 @@ class Channel:
|
||||||
tflat = coalesce(cfg.trigger_stereo, corr_cfg.trigger_stereo)
|
tflat = coalesce(cfg.trigger_stereo, corr_cfg.trigger_stereo)
|
||||||
rflat = coalesce(cfg.render_stereo, corr_cfg.render_stereo)
|
rflat = coalesce(cfg.render_stereo, corr_cfg.render_stereo)
|
||||||
|
|
||||||
self.trigger_wave = wave.with_flatten(tflat)
|
self.trigger_wave = wave.with_flatten(tflat, return_channels=False)
|
||||||
self.render_wave = wave.with_flatten(rflat)
|
self.render_wave = wave.with_flatten(rflat, return_channels=True)
|
||||||
|
|
||||||
# `subsampling` increases `stride` and decreases `nsamp`.
|
# `subsampling` increases `stride` and decreases `nsamp`.
|
||||||
# `width` increases `stride` without changing `nsamp`.
|
# `width` increases `stride` without changing `nsamp`.
|
||||||
|
|
|
@ -1,16 +1,41 @@
|
||||||
from typing import Optional, TypeVar, Callable, List, Generic, Tuple
|
import collections
|
||||||
|
import enum
|
||||||
|
from enum import auto
|
||||||
|
from typing import Optional, TypeVar, Callable, List, Iterable
|
||||||
|
|
||||||
|
import attr
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from corrscope.config import DumpableAttrs, CorrError
|
from corrscope.config import DumpableAttrs, CorrError, DumpEnumAsStr
|
||||||
from corrscope.util import ceildiv
|
from corrscope.util import ceildiv
|
||||||
|
|
||||||
|
|
||||||
class LayoutConfig(DumpableAttrs, always_dump="orientation"):
|
class Orientation(str, DumpEnumAsStr):
|
||||||
orientation: str = "h"
|
h = "h"
|
||||||
|
v = "v"
|
||||||
|
|
||||||
|
|
||||||
|
class StereoOrientation(str, DumpEnumAsStr):
|
||||||
|
h = "h"
|
||||||
|
v = "v"
|
||||||
|
overlay = "overlay"
|
||||||
|
|
||||||
|
|
||||||
|
assert Orientation.h == StereoOrientation.h
|
||||||
|
H = Orientation.h
|
||||||
|
V = Orientation.v
|
||||||
|
OVERLAY = StereoOrientation.overlay
|
||||||
|
|
||||||
|
|
||||||
|
class LayoutConfig(DumpableAttrs, always_dump="orientation stereo_orientation"):
|
||||||
|
orientation: Orientation = attr.ib(default="h", converter=Orientation)
|
||||||
nrows: Optional[int] = None
|
nrows: Optional[int] = None
|
||||||
ncols: Optional[int] = None
|
ncols: Optional[int] = None
|
||||||
|
|
||||||
|
stereo_orientation: StereoOrientation = attr.ib(
|
||||||
|
default="h", converter=StereoOrientation
|
||||||
|
)
|
||||||
|
|
||||||
def __attrs_post_init__(self) -> None:
|
def __attrs_post_init__(self) -> None:
|
||||||
if not self.nrows:
|
if not self.nrows:
|
||||||
self.nrows = None
|
self.nrows = None
|
||||||
|
@ -24,31 +49,90 @@ class LayoutConfig(DumpableAttrs, always_dump="orientation"):
|
||||||
self.ncols = 1
|
self.ncols = 1
|
||||||
|
|
||||||
|
|
||||||
|
class Edges(enum.Flag):
|
||||||
|
NONE = 0
|
||||||
|
Top = auto()
|
||||||
|
Left = auto()
|
||||||
|
Bottom = auto()
|
||||||
|
Right = auto()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def at(nrows: int, ncols: int, row: int, col: int):
|
||||||
|
if not nrows > 0:
|
||||||
|
raise ValueError(f"invalid nrows={nrows}, must be positive")
|
||||||
|
if not ncols > 0:
|
||||||
|
raise ValueError(f"invalid ncols={ncols}, must be positive")
|
||||||
|
if not 0 <= row < nrows:
|
||||||
|
raise ValueError(f"invalid row={row} not in [0 .. nrows={nrows})")
|
||||||
|
if not 0 <= col < ncols:
|
||||||
|
raise ValueError(f"invalid col={col} not in [0 .. ncols={ncols})")
|
||||||
|
|
||||||
|
ret = Edges.NONE
|
||||||
|
if row == 0:
|
||||||
|
ret |= Edges.Top
|
||||||
|
if row + 1 == nrows:
|
||||||
|
ret |= Edges.Bottom
|
||||||
|
if col == 0:
|
||||||
|
ret |= Edges.Left
|
||||||
|
if col + 1 == ncols:
|
||||||
|
ret |= Edges.Right
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
def attr_idx_property(key: str, idx: int) -> property:
|
||||||
|
@property
|
||||||
|
def inner(self: "RegionSpec"):
|
||||||
|
return getattr(self, key)[idx]
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
||||||
|
|
||||||
|
@attr.dataclass
|
||||||
|
class RegionSpec:
|
||||||
|
"""
|
||||||
|
- Origin is located at top-left.
|
||||||
|
- Row 0 = top.
|
||||||
|
- Row nrows-1 = bottom.
|
||||||
|
- Col 0 = left.
|
||||||
|
- Col ncols-1 = right.
|
||||||
|
"""
|
||||||
|
|
||||||
|
size: np.ndarray
|
||||||
|
pos: np.ndarray
|
||||||
|
|
||||||
|
nrow = attr_idx_property("size", 0)
|
||||||
|
ncol = attr_idx_property("size", 1)
|
||||||
|
row = attr_idx_property("pos", 0)
|
||||||
|
col = attr_idx_property("pos", 1)
|
||||||
|
|
||||||
|
screen_edges: "Edges"
|
||||||
|
wave_edges: "Edges"
|
||||||
|
|
||||||
|
|
||||||
Region = TypeVar("Region")
|
Region = TypeVar("Region")
|
||||||
RegionFactory = Callable[[int, int], Region] # f(row, column) -> Region
|
RegionFactory = Callable[[RegionSpec], Region] # f(row, column) -> Region
|
||||||
|
|
||||||
|
|
||||||
class RendererLayout:
|
class RendererLayout:
|
||||||
VALID_ORIENTATIONS = ["h", "v"]
|
def __init__(self, cfg: LayoutConfig, wave_nchans: List[int]):
|
||||||
|
|
||||||
def __init__(self, cfg: LayoutConfig, nplots: int):
|
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
self.nplots = nplots
|
self.nwaves = len(wave_nchans)
|
||||||
|
self.wave_nchans = wave_nchans
|
||||||
# Setup layout
|
|
||||||
self.nrows, self.ncols = self._calc_layout()
|
|
||||||
|
|
||||||
self.orientation = cfg.orientation
|
self.orientation = cfg.orientation
|
||||||
if self.orientation not in self.VALID_ORIENTATIONS:
|
self.stereo_orientation = cfg.stereo_orientation
|
||||||
raise CorrError(
|
|
||||||
f"Invalid orientation {self.orientation} not in "
|
|
||||||
f"{self.VALID_ORIENTATIONS}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _calc_layout(self) -> Tuple[int, int]:
|
# Setup layout
|
||||||
|
self._calc_layout()
|
||||||
|
|
||||||
|
# Shape of wave slots
|
||||||
|
wave_nrow: int
|
||||||
|
wave_ncol: int
|
||||||
|
|
||||||
|
def _calc_layout(self) -> None:
|
||||||
"""
|
"""
|
||||||
Inputs: self.cfg, self.waves
|
Inputs: self.cfg, self.stereo_nchan
|
||||||
:return: (nrows, ncols)
|
Outputs: self.wave_nrow, ncol
|
||||||
"""
|
"""
|
||||||
cfg = self.cfg
|
cfg = self.cfg
|
||||||
|
|
||||||
|
@ -56,7 +140,7 @@ class RendererLayout:
|
||||||
nrows = cfg.nrows
|
nrows = cfg.nrows
|
||||||
if nrows is None:
|
if nrows is None:
|
||||||
raise ValueError("impossible cfg: nrows is None and true")
|
raise ValueError("impossible cfg: nrows is None and true")
|
||||||
ncols = ceildiv(self.nplots, nrows)
|
ncols = ceildiv(self.nwaves, nrows)
|
||||||
else:
|
else:
|
||||||
if cfg.ncols is None:
|
if cfg.ncols is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -64,38 +148,110 @@ class RendererLayout:
|
||||||
"(__attrs_post_init__ not called?)"
|
"(__attrs_post_init__ not called?)"
|
||||||
)
|
)
|
||||||
ncols = cfg.ncols
|
ncols = cfg.ncols
|
||||||
nrows = ceildiv(self.nplots, ncols)
|
nrows = ceildiv(self.nwaves, ncols)
|
||||||
|
|
||||||
return nrows, ncols
|
self.wave_nrow = nrows
|
||||||
|
self.wave_ncol = ncols
|
||||||
|
|
||||||
def arrange(self, region_factory: RegionFactory[Region]) -> List[Region]:
|
def arrange(self, region_factory: RegionFactory[Region]) -> List[List[Region]]:
|
||||||
""" Generates an array of regions.
|
|
||||||
|
|
||||||
index, row, column are fed into region_factory in a row-major order [row][col].
|
|
||||||
The results are possibly reshaped into column-major order [col][row].
|
|
||||||
"""
|
"""
|
||||||
nspaces = self.nrows * self.ncols
|
(row, column) are fed into region_factory in a row-major order [row][col].
|
||||||
inds = np.arange(nspaces)
|
Stereo channel pairs are extracted.
|
||||||
rows, cols = np.unravel_index(inds, (self.nrows, self.ncols))
|
The results are possibly reshaped into column-major order [col][row].
|
||||||
|
|
||||||
row_col = list(zip(rows, cols))
|
:return arr[wave][channel] = Region
|
||||||
regions = np.empty(len(row_col), dtype=object) # type: np.ndarray[Region]
|
"""
|
||||||
regions[:] = [region_factory(*rc) for rc in row_col]
|
|
||||||
|
|
||||||
regions2d = regions.reshape(
|
wave_spaces = self.wave_nrow * self.wave_ncol
|
||||||
(self.nrows, self.ncols)
|
inds = np.arange(wave_spaces)
|
||||||
) # type: np.ndarray[Region]
|
|
||||||
|
|
||||||
# if column major:
|
# Compute location of each wave.
|
||||||
if self.orientation == "v":
|
if self.orientation == V:
|
||||||
regions2d = regions2d.T
|
# column major
|
||||||
|
cols, rows = np.unravel_index(inds, (self.wave_ncol, self.wave_nrow))
|
||||||
|
else:
|
||||||
|
# row major
|
||||||
|
rows, cols = np.unravel_index(inds, (self.wave_nrow, self.wave_ncol))
|
||||||
|
|
||||||
return regions2d.flatten()[: self.nplots].tolist()
|
# Generate plot for each wave.chan. Leave unused slots empty.
|
||||||
|
region_wave_chan: List[List[Region]] = []
|
||||||
|
|
||||||
|
# The order of (rows, cols) has no effect.
|
||||||
|
for stereo_nchan, wave_row, wave_col in zip(self.wave_nchans, rows, cols):
|
||||||
|
# Wave = within Screen.
|
||||||
|
# Chan = within Wave, generate a plot.
|
||||||
|
# All arrays are [y, x] == [row, col].
|
||||||
|
|
||||||
|
# Wave dim/pos (within screen)
|
||||||
|
waves_per_screen = arr(self.wave_nrow, self.wave_ncol)
|
||||||
|
wave_screen_pos = arr(wave_row, wave_col)
|
||||||
|
del wave_row, wave_col
|
||||||
|
|
||||||
|
# Channel dim/pos (within wave)
|
||||||
|
chans_per_wave = arr(1, 1) # Mutated based on orientation
|
||||||
|
chan_wave_pos = arr(0, 0) # Mutated in for-chan loop.
|
||||||
|
|
||||||
|
# Distance between chans
|
||||||
|
dchan = arr(0, 0) # Mutated based on orientation
|
||||||
|
|
||||||
|
if self.stereo_orientation == V:
|
||||||
|
chans_per_wave[0] = stereo_nchan
|
||||||
|
dchan[0] = 1
|
||||||
|
elif self.stereo_orientation == H:
|
||||||
|
chans_per_wave[1] = stereo_nchan
|
||||||
|
dchan[1] = 1
|
||||||
|
else:
|
||||||
|
assert self.stereo_orientation == OVERLAY
|
||||||
|
|
||||||
|
# Channel dim/pos (within screen)
|
||||||
|
chans_per_screen = chans_per_wave * waves_per_screen
|
||||||
|
chan_screen_pos = chans_per_wave * wave_screen_pos
|
||||||
|
|
||||||
|
# Generate plots for each channel
|
||||||
|
region_chan: List[Region] = []
|
||||||
|
region_wave_chan.append(region_chan)
|
||||||
|
region = None
|
||||||
|
|
||||||
|
for chan in range(stereo_nchan):
|
||||||
|
assert (chan_wave_pos < chans_per_wave).all()
|
||||||
|
assert (chan_screen_pos < chans_per_screen).all()
|
||||||
|
|
||||||
|
# Generate plot (channel position in screen)
|
||||||
|
if region is None or dchan.any():
|
||||||
|
screen_edges = Edges.at(*chans_per_screen, *chan_screen_pos)
|
||||||
|
wave_edges = Edges.at(*chans_per_wave, *chan_wave_pos)
|
||||||
|
|
||||||
|
# Removing .copy() causes bugs if region_factory() holds
|
||||||
|
# mutable references.
|
||||||
|
chan_spec = RegionSpec(
|
||||||
|
chans_per_screen.copy(),
|
||||||
|
chan_screen_pos.copy(),
|
||||||
|
screen_edges,
|
||||||
|
wave_edges,
|
||||||
|
)
|
||||||
|
region = region_factory(chan_spec)
|
||||||
|
region_chan.append(region)
|
||||||
|
|
||||||
|
# Move to next channel position
|
||||||
|
chan_screen_pos += dchan
|
||||||
|
chan_wave_pos += dchan
|
||||||
|
|
||||||
|
assert len(region_wave_chan) == self.nwaves
|
||||||
|
return region_wave_chan
|
||||||
|
|
||||||
|
|
||||||
class EdgeFinder(Generic[Region]):
|
def arr(*args):
|
||||||
def __init__(self, regions2d: np.ndarray):
|
return np.array(args)
|
||||||
self.tops: List[Region] = regions2d[0, :].tolist()
|
|
||||||
self.bottoms: List[Region] = regions2d[-1, :].tolist()
|
|
||||||
self.lefts: List[Region] = regions2d[:, 0].tolist()
|
T = TypeVar("T")
|
||||||
self.rights: List[Region] = regions2d[:, -1].tolist()
|
|
||||||
|
|
||||||
|
def unique_by_id(items: Iterable[T]) -> List[T]:
|
||||||
|
seen = collections.OrderedDict()
|
||||||
|
|
||||||
|
for item in items:
|
||||||
|
if id(item) not in seen:
|
||||||
|
seen[id(item)] = item
|
||||||
|
|
||||||
|
return list(seen.values())
|
||||||
|
|
|
@ -4,10 +4,17 @@ from typing import Optional, List, TYPE_CHECKING, Any
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
import matplotlib
|
import matplotlib
|
||||||
|
import matplotlib.colors
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from corrscope.config import DumpableAttrs, with_units
|
from corrscope.config import DumpableAttrs, with_units
|
||||||
from corrscope.layout import RendererLayout, LayoutConfig, EdgeFinder
|
from corrscope.layout import (
|
||||||
|
RendererLayout,
|
||||||
|
LayoutConfig,
|
||||||
|
unique_by_id,
|
||||||
|
RegionSpec,
|
||||||
|
Edges,
|
||||||
|
)
|
||||||
from corrscope.outputs import RGB_DEPTH, ByteBuffer
|
from corrscope.outputs import RGB_DEPTH, ByteBuffer
|
||||||
from corrscope.util import coalesce
|
from corrscope.util import coalesce
|
||||||
|
|
||||||
|
@ -58,7 +65,10 @@ class RendererConfig(DumpableAttrs, always_dump="*"):
|
||||||
|
|
||||||
bg_color: str = "#000000"
|
bg_color: str = "#000000"
|
||||||
init_line_color: str = default_color()
|
init_line_color: str = default_color()
|
||||||
|
|
||||||
grid_color: Optional[str] = None
|
grid_color: Optional[str] = None
|
||||||
|
stereo_grid_opacity: float = 0.5
|
||||||
|
|
||||||
midline_color: Optional[str] = None
|
midline_color: Optional[str] = None
|
||||||
v_midline: bool = False
|
v_midline: bool = False
|
||||||
h_midline: bool = False
|
h_midline: bool = False
|
||||||
|
@ -99,8 +109,8 @@ class Renderer(ABC):
|
||||||
channel_cfgs: Optional[List["ChannelConfig"]],
|
channel_cfgs: Optional[List["ChannelConfig"]],
|
||||||
):
|
):
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
|
self.lcfg = lcfg
|
||||||
self.nplots = nplots
|
self.nplots = nplots
|
||||||
self.layout = RendererLayout(lcfg, nplots)
|
|
||||||
|
|
||||||
# Load line colors.
|
# Load line colors.
|
||||||
if channel_cfgs is not None:
|
if channel_cfgs is not None:
|
||||||
|
@ -165,16 +175,20 @@ class MatplotlibRenderer(Renderer):
|
||||||
matplotlib.rcParams, "lines.antialiased", self.cfg.antialiasing
|
matplotlib.rcParams, "lines.antialiased", self.cfg.antialiasing
|
||||||
)
|
)
|
||||||
|
|
||||||
# Flat array of nrows*ncols elements, ordered by cfg.rows_first.
|
|
||||||
self._fig: "Figure"
|
self._fig: "Figure"
|
||||||
self._axes: List["Axes"] # set by set_layout()
|
|
||||||
self._lines: Optional[List["Line2D"]] = None # set by render_frame() first call
|
|
||||||
|
|
||||||
self._set_layout() # mutates self
|
# _axes2d[wave][chan] = Axes
|
||||||
|
self._axes2d: List[List["Axes"]] # set by set_layout()
|
||||||
|
|
||||||
|
# _lines2d[wave][chan] = Line2D
|
||||||
|
self._lines2d: List[List[Line2D]] = []
|
||||||
|
self._lines_flat: List["Line2D"] = []
|
||||||
|
|
||||||
transparent = "#00000000"
|
transparent = "#00000000"
|
||||||
|
|
||||||
def _set_layout(self) -> None:
|
layout: RendererLayout
|
||||||
|
|
||||||
|
def _set_layout(self, wave_nchans: List[int]) -> None:
|
||||||
"""
|
"""
|
||||||
Creates a flat array of Matplotlib Axes, with the new layout.
|
Creates a flat array of Matplotlib Axes, with the new layout.
|
||||||
Opens a window showing the Figure (and Axes).
|
Opens a window showing the Figure (and Axes).
|
||||||
|
@ -183,6 +197,8 @@ class MatplotlibRenderer(Renderer):
|
||||||
Outputs: self.nrows, self.ncols, self.axes
|
Outputs: self.nrows, self.ncols, self.axes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
self.layout = RendererLayout(self.lcfg, wave_nchans)
|
||||||
|
|
||||||
# Create Axes
|
# Create Axes
|
||||||
# https://matplotlib.org/api/_as_gen/matplotlib.pyplot.subplots.html
|
# https://matplotlib.org/api/_as_gen/matplotlib.pyplot.subplots.html
|
||||||
if hasattr(self, "_fig"):
|
if hasattr(self, "_fig"):
|
||||||
|
@ -190,24 +206,24 @@ class MatplotlibRenderer(Renderer):
|
||||||
# plt.close(self.fig)
|
# plt.close(self.fig)
|
||||||
|
|
||||||
grid_color = self.cfg.grid_color
|
grid_color = self.cfg.grid_color
|
||||||
axes2d: np.ndarray["Axes"]
|
|
||||||
self._fig = Figure()
|
self._fig = Figure()
|
||||||
FigureCanvasAgg(self._fig)
|
FigureCanvasAgg(self._fig)
|
||||||
|
|
||||||
axes2d = self._fig.subplots(
|
# RegionFactory
|
||||||
self.layout.nrows,
|
def axes_factory(r: RegionSpec) -> "Axes":
|
||||||
self.layout.ncols,
|
width = 1 / r.ncol
|
||||||
squeeze=False,
|
left = r.col / r.ncol
|
||||||
# Remove axis ticks (which slow down rendering)
|
assert 0 <= left < 1
|
||||||
subplot_kw=dict(xticks=[], yticks=[]),
|
|
||||||
# Remove gaps between Axes TODO borders shouldn't be half-visible
|
|
||||||
gridspec_kw=dict(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0),
|
|
||||||
)
|
|
||||||
|
|
||||||
ax: "Axes"
|
height = 1 / r.nrow
|
||||||
if grid_color:
|
bottom = (r.nrow - r.row - 1) / r.nrow
|
||||||
# Initialize borders
|
assert 0 <= bottom < 1
|
||||||
for ax in axes2d.flatten():
|
|
||||||
|
# Disabling xticks/yticks is unnecessary, since we hide Axises.
|
||||||
|
ax = self._fig.add_axes([left, bottom, width, height], xticks=[], yticks=[])
|
||||||
|
|
||||||
|
if grid_color:
|
||||||
|
# Initialize borders
|
||||||
# Hide Axises
|
# Hide Axises
|
||||||
# (drawing them is very slow, and we disable ticks+labels anyway)
|
# (drawing them is very slow, and we disable ticks+labels anyway)
|
||||||
ax.get_xaxis().set_visible(False)
|
ax.get_xaxis().set_visible(False)
|
||||||
|
@ -223,25 +239,44 @@ class MatplotlibRenderer(Renderer):
|
||||||
for spine in ax.spines.values():
|
for spine in ax.spines.values():
|
||||||
spine.set_color(grid_color)
|
spine.set_color(grid_color)
|
||||||
|
|
||||||
# gridspec_kw indexes from bottom-left corner.
|
def hide(key: str):
|
||||||
# Only show bottom-left borders (x=0, y=0)
|
ax.spines[key].set_visible(False)
|
||||||
ax.spines["top"].set_visible(False)
|
|
||||||
ax.spines["right"].set_visible(False)
|
|
||||||
|
|
||||||
# Hide bottom-left edges for speed.
|
# Hide all axes except bottom-right.
|
||||||
edge_axes: EdgeFinder["Axes"] = EdgeFinder(axes2d)
|
hide("top")
|
||||||
for ax in edge_axes.bottoms:
|
hide("left")
|
||||||
ax.spines["bottom"].set_visible(False)
|
|
||||||
for ax in edge_axes.lefts:
|
|
||||||
ax.spines["left"].set_visible(False)
|
|
||||||
|
|
||||||
else:
|
# If bottom of screen, hide bottom. If right of screen, hide right.
|
||||||
# Remove Axis from Axes
|
if r.screen_edges & Edges.Bottom:
|
||||||
for ax in axes2d.flatten():
|
hide("bottom")
|
||||||
|
if r.screen_edges & Edges.Right:
|
||||||
|
hide("right")
|
||||||
|
|
||||||
|
# Dim stereo gridlines
|
||||||
|
if self.cfg.stereo_grid_opacity > 0:
|
||||||
|
dim_color = matplotlib.colors.to_rgba_array(grid_color)[0]
|
||||||
|
dim_color[-1] = self.cfg.stereo_grid_opacity
|
||||||
|
|
||||||
|
def dim(key: str):
|
||||||
|
ax.spines[key].set_color(dim_color)
|
||||||
|
|
||||||
|
else:
|
||||||
|
dim = hide
|
||||||
|
|
||||||
|
# If not bottom of wave, dim bottom. If not right of wave, dim right.
|
||||||
|
if not r.wave_edges & Edges.Bottom:
|
||||||
|
dim("bottom")
|
||||||
|
if not r.wave_edges & Edges.Right:
|
||||||
|
dim("right")
|
||||||
|
|
||||||
|
else:
|
||||||
ax.set_axis_off()
|
ax.set_axis_off()
|
||||||
|
|
||||||
# Generate arrangement (using nplots, cfg.orientation)
|
return ax
|
||||||
self._axes: List[Axes] = self.layout.arrange(lambda row, col: axes2d[row, col])
|
|
||||||
|
# Generate arrangement (using self.lcfg, wave_nchans)
|
||||||
|
# _axes2d[wave][chan] = Axes
|
||||||
|
self._axes2d = self.layout.arrange(axes_factory)
|
||||||
|
|
||||||
# Setup figure geometry
|
# Setup figure geometry
|
||||||
self._fig.set_dpi(DPI)
|
self._fig.set_dpi(DPI)
|
||||||
|
@ -255,45 +290,66 @@ class MatplotlibRenderer(Renderer):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize axes and draw waveform data
|
# Initialize axes and draw waveform data
|
||||||
if self._lines is None:
|
if not self._lines2d:
|
||||||
|
assert len(datas[0].shape) == 2, datas[0].shape
|
||||||
|
|
||||||
|
wave_nchans = [data.shape[1] for data in datas]
|
||||||
|
self._set_layout(wave_nchans)
|
||||||
|
|
||||||
cfg = self.cfg
|
cfg = self.cfg
|
||||||
|
|
||||||
# Setup background/axes
|
# Setup background/axes
|
||||||
self._fig.set_facecolor(cfg.bg_color)
|
self._fig.set_facecolor(cfg.bg_color)
|
||||||
for idx, data in enumerate(datas):
|
for idx, wave_data in enumerate(datas):
|
||||||
ax = self._axes[idx]
|
wave_axes = self._axes2d[idx]
|
||||||
max_x = len(data) - 1
|
for ax in unique_by_id(wave_axes):
|
||||||
ax.set_xlim(0, max_x)
|
max_x = len(wave_data) - 1
|
||||||
ax.set_ylim(-1, 1)
|
ax.set_xlim(0, max_x)
|
||||||
|
ax.set_ylim(-1, 1)
|
||||||
|
|
||||||
# Setup midlines
|
# Setup midlines (depends on max_x and wave_data)
|
||||||
midline_color = cfg.midline_color
|
midline_color = cfg.midline_color
|
||||||
midline_width = pixels(1)
|
midline_width = pixels(1)
|
||||||
|
|
||||||
# zorder=-100 still draws on top of gridlines :(
|
# zorder=-100 still draws on top of gridlines :(
|
||||||
kw = dict(color=midline_color, linewidth=midline_width)
|
kw = dict(color=midline_color, linewidth=midline_width)
|
||||||
if cfg.v_midline:
|
if cfg.v_midline:
|
||||||
ax.axvline(x=max_x / 2, **kw)
|
ax.axvline(x=max_x / 2, **kw)
|
||||||
if cfg.h_midline:
|
if cfg.h_midline:
|
||||||
ax.axhline(y=0, **kw)
|
ax.axhline(y=0, **kw)
|
||||||
|
|
||||||
self._save_background()
|
self._save_background()
|
||||||
|
|
||||||
# Plot lines over background
|
# Plot lines over background
|
||||||
line_width = pixels(cfg.line_width)
|
line_width = pixels(cfg.line_width)
|
||||||
self._lines = []
|
|
||||||
|
|
||||||
for idx, data in enumerate(datas):
|
# Foreach wave
|
||||||
ax = self._axes[idx]
|
for wave_idx, wave_data in enumerate(datas):
|
||||||
line_color = self._line_params[idx].color
|
wave_axes = self._axes2d[wave_idx]
|
||||||
line = ax.plot(data, color=line_color, linewidth=line_width)[0]
|
wave_lines = []
|
||||||
self._lines.append(line)
|
|
||||||
|
# Foreach chan
|
||||||
|
for chan_idx, chan_data in enumerate(wave_data.T):
|
||||||
|
ax = wave_axes[chan_idx]
|
||||||
|
line_color = self._line_params[wave_idx].color
|
||||||
|
chan_line: Line2D = ax.plot(
|
||||||
|
chan_data, color=line_color, linewidth=line_width
|
||||||
|
)[0]
|
||||||
|
wave_lines.append(chan_line)
|
||||||
|
|
||||||
|
self._lines2d.append(wave_lines)
|
||||||
|
self._lines_flat.extend(wave_lines)
|
||||||
|
|
||||||
# Draw waveform data
|
# Draw waveform data
|
||||||
else:
|
else:
|
||||||
for idx, data in enumerate(datas):
|
# Foreach wave
|
||||||
line = self._lines[idx]
|
for wave_idx, wave_data in enumerate(datas):
|
||||||
line.set_ydata(data)
|
wave_lines = self._lines2d[wave_idx]
|
||||||
|
|
||||||
|
# Foreach chan
|
||||||
|
for chan_idx, chan_data in enumerate(wave_data.T):
|
||||||
|
chan_line = wave_lines[chan_idx]
|
||||||
|
chan_line.set_ydata(chan_data)
|
||||||
|
|
||||||
self._redraw_over_background()
|
self._redraw_over_background()
|
||||||
|
|
||||||
|
@ -314,8 +370,7 @@ class MatplotlibRenderer(Renderer):
|
||||||
canvas: FigureCanvasAgg = self._fig.canvas
|
canvas: FigureCanvasAgg = self._fig.canvas
|
||||||
canvas.restore_region(self.bg_cache)
|
canvas.restore_region(self.bg_cache)
|
||||||
|
|
||||||
assert self._lines is not None
|
for line in self._lines_flat:
|
||||||
for line in self._lines:
|
|
||||||
line.axes.draw_artist(line)
|
line.axes.draw_artist(line)
|
||||||
|
|
||||||
# https://bastibe.de/2013-05-30-speeding-up-matplotlib.html
|
# https://bastibe.de/2013-05-30-speeding-up-matplotlib.html
|
||||||
|
|
|
@ -40,7 +40,7 @@ class Wave:
|
||||||
__slots__ = """
|
__slots__ = """
|
||||||
wave_path
|
wave_path
|
||||||
amplification
|
amplification
|
||||||
smp_s data _flatten is_mono
|
smp_s data return_channels _flatten is_mono
|
||||||
nsamp dtype
|
nsamp dtype
|
||||||
center max_val
|
center max_val
|
||||||
""".split()
|
""".split()
|
||||||
|
@ -91,6 +91,7 @@ class Wave:
|
||||||
assert self.data.ndim in [1, 2]
|
assert self.data.ndim in [1, 2]
|
||||||
self.is_mono = self.data.ndim == 1
|
self.is_mono = self.data.ndim == 1
|
||||||
self.flatten = flatten
|
self.flatten = flatten
|
||||||
|
self.return_channels = False
|
||||||
|
|
||||||
# Cast self.data to stereo (nsamp, nchan)
|
# Cast self.data to stereo (nsamp, nchan)
|
||||||
if self.is_mono:
|
if self.is_mono:
|
||||||
|
@ -130,9 +131,10 @@ class Wave:
|
||||||
else:
|
else:
|
||||||
raise CorrError(f"unexpected wavfile dtype {dtype}")
|
raise CorrError(f"unexpected wavfile dtype {dtype}")
|
||||||
|
|
||||||
def with_flatten(self, flatten: Flatten) -> "Wave":
|
def with_flatten(self, flatten: Flatten, return_channels: bool) -> "Wave":
|
||||||
new = copy.copy(self)
|
new = copy.copy(self)
|
||||||
new.flatten = flatten
|
new.flatten = flatten
|
||||||
|
new.return_channels = return_channels
|
||||||
return new
|
return new
|
||||||
|
|
||||||
def __getitem__(self, index: Union[int, slice]) -> np.ndarray:
|
def __getitem__(self, index: Union[int, slice]) -> np.ndarray:
|
||||||
|
@ -154,6 +156,9 @@ class Wave:
|
||||||
|
|
||||||
data -= self.center
|
data -= self.center
|
||||||
data *= self.amplification / self.max_val
|
data *= self.amplification / self.max_val
|
||||||
|
|
||||||
|
if self.return_channels and len(data.shape) == 1:
|
||||||
|
data = data.reshape(-1, 1)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def _get(self, begin: int, end: int, subsampling: int) -> np.ndarray:
|
def _get(self, begin: int, end: int, subsampling: int) -> np.ndarray:
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
"""
|
"""
|
||||||
Integration tests found in:
|
Integration tests found in:
|
||||||
- test_cli.py
|
- test_cli.py
|
||||||
|
- test_renderer.py
|
||||||
- test_output.py
|
- test_output.py
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@ from corrscope.channel import ChannelConfig, Channel
|
||||||
from corrscope.corrscope import default_config, CorrScope, BenchmarkMode, Arguments
|
from corrscope.corrscope import default_config, CorrScope, BenchmarkMode, Arguments
|
||||||
from corrscope.triggers import NullTriggerConfig
|
from corrscope.triggers import NullTriggerConfig
|
||||||
from corrscope.util import coalesce
|
from corrscope.util import coalesce
|
||||||
|
from corrscope.wave import Flatten
|
||||||
|
|
||||||
|
|
||||||
positive = hs.integers(min_value=1, max_value=100)
|
positive = hs.integers(min_value=1, max_value=100)
|
||||||
|
@ -134,3 +135,34 @@ def test_config_channel_width_stride(
|
||||||
|
|
||||||
|
|
||||||
# line_color is tested in test_renderer.py
|
# line_color is tested in test_renderer.py
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("filename", ["tests/sine440.wav", "tests/stereo in-phase.wav"])
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("global_stereo", "chan_stereo"),
|
||||||
|
[
|
||||||
|
[Flatten.SumAvg, None],
|
||||||
|
[Flatten.Stereo, None],
|
||||||
|
[Flatten.SumAvg, Flatten.Stereo],
|
||||||
|
[Flatten.Stereo, Flatten.SumAvg],
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_per_channel_stereo(
|
||||||
|
filename: str, global_stereo: Flatten, chan_stereo: Optional[Flatten]
|
||||||
|
):
|
||||||
|
"""Ensure you can enable/disable stereo on a per-channel basis."""
|
||||||
|
stereo = coalesce(chan_stereo, global_stereo)
|
||||||
|
|
||||||
|
# Test render wave.
|
||||||
|
cfg = default_config(render_stereo=global_stereo)
|
||||||
|
ccfg = ChannelConfig("tests/stereo in-phase.wav", render_stereo=chan_stereo)
|
||||||
|
channel = Channel(ccfg, cfg)
|
||||||
|
|
||||||
|
# Render wave *must* return stereo.
|
||||||
|
assert channel.render_wave[0:1].ndim == 2
|
||||||
|
data = channel.render_wave.get_around(0, return_nsamp=4, stride=1)
|
||||||
|
assert data.ndim == 2
|
||||||
|
|
||||||
|
if "stereo" in filename:
|
||||||
|
assert channel.render_wave._flatten == stereo
|
||||||
|
assert data.shape[1] == (2 if stereo is Flatten.Stereo else 1)
|
||||||
|
|
|
@ -1,9 +1,22 @@
|
||||||
import numpy as np
|
from typing import List
|
||||||
import pytest
|
|
||||||
|
|
||||||
from corrscope.layout import LayoutConfig, RendererLayout, EdgeFinder
|
import hypothesis.strategies as hs
|
||||||
|
import numpy as np
|
||||||
|
import numpy.testing as npt
|
||||||
|
import pytest
|
||||||
|
from hypothesis import given, settings
|
||||||
|
|
||||||
|
from corrscope.layout import (
|
||||||
|
LayoutConfig,
|
||||||
|
RendererLayout,
|
||||||
|
RegionSpec,
|
||||||
|
Edges,
|
||||||
|
Orientation,
|
||||||
|
StereoOrientation,
|
||||||
|
)
|
||||||
from corrscope.renderer import RendererConfig, MatplotlibRenderer
|
from corrscope.renderer import RendererConfig, MatplotlibRenderer
|
||||||
from tests.test_renderer import WIDTH, HEIGHT
|
from corrscope.util import ceildiv
|
||||||
|
from tests.test_renderer import WIDTH, HEIGHT, RENDER_Y_ZEROS
|
||||||
|
|
||||||
|
|
||||||
def test_layout_config():
|
def test_layout_config():
|
||||||
|
@ -22,44 +35,184 @@ def test_layout_config():
|
||||||
assert default.orientation == "h"
|
assert default.orientation == "h"
|
||||||
|
|
||||||
|
|
||||||
|
# Small range to ensure many collisions.
|
||||||
|
# max_value = 3 to allow for edge-free space in center.
|
||||||
|
integers = hs.integers(-1, 3)
|
||||||
|
|
||||||
|
|
||||||
|
@given(nrows=integers, ncols=integers, row=integers, col=integers)
|
||||||
|
@settings(max_examples=500)
|
||||||
|
def test_edges(nrows: int, ncols: int, row: int, col: int):
|
||||||
|
if not (nrows > 0 and ncols > 0 and 0 <= row < nrows and 0 <= col < ncols):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
edges = Edges.at(nrows, ncols, row, col)
|
||||||
|
return
|
||||||
|
|
||||||
|
edges = Edges.at(nrows, ncols, row, col)
|
||||||
|
assert bool(edges & Edges.Left) == (col == 0)
|
||||||
|
assert bool(edges & Edges.Right) == (col == ncols - 1)
|
||||||
|
assert bool(edges & Edges.Top) == (row == 0)
|
||||||
|
assert bool(edges & Edges.Bottom) == (row == nrows - 1)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("lcfg", [LayoutConfig(ncols=2), LayoutConfig(nrows=8)])
|
@pytest.mark.parametrize("lcfg", [LayoutConfig(ncols=2), LayoutConfig(nrows=8)])
|
||||||
@pytest.mark.parametrize("region_type", [str, tuple, list])
|
def test_hlayout(lcfg):
|
||||||
def test_hlayout(lcfg, region_type):
|
|
||||||
nplots = 15
|
nplots = 15
|
||||||
layout = RendererLayout(lcfg, nplots)
|
layout = RendererLayout(lcfg, [1] * nplots)
|
||||||
|
|
||||||
assert layout.ncols == 2
|
assert layout.wave_ncol == 2
|
||||||
assert layout.nrows == 8
|
assert layout.wave_nrow == 8
|
||||||
|
|
||||||
regions = layout.arrange(lambda row, col: region_type((row, col)))
|
region2d: List[List[RegionSpec]] = layout.arrange(lambda arg: arg)
|
||||||
assert len(regions) == nplots
|
assert len(region2d) == nplots
|
||||||
|
for i, regions in enumerate(region2d):
|
||||||
|
assert len(regions) == 1, (i, len(regions))
|
||||||
|
|
||||||
|
np.testing.assert_equal(region2d[0][0].pos, (0, 0))
|
||||||
|
np.testing.assert_equal(region2d[1][0].pos, (0, 1))
|
||||||
|
np.testing.assert_equal(region2d[2][0].pos, (1, 0))
|
||||||
|
|
||||||
assert regions[0] == region_type((0, 0))
|
|
||||||
assert regions[1] == region_type((0, 1))
|
|
||||||
assert regions[2] == region_type((1, 0))
|
|
||||||
m = nplots - 1
|
m = nplots - 1
|
||||||
assert regions[m] == region_type((m // 2, m % 2))
|
npt.assert_equal(region2d[m][0].pos, (m // 2, m % 2))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"lcfg",
|
"lcfg",
|
||||||
[LayoutConfig(ncols=3, orientation="v"), LayoutConfig(nrows=3, orientation="v")],
|
[LayoutConfig(ncols=3, orientation="v"), LayoutConfig(nrows=3, orientation="v")],
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("region_type", [str, tuple, list])
|
def test_vlayout(lcfg):
|
||||||
def test_vlayout(lcfg, region_type):
|
|
||||||
nplots = 7
|
nplots = 7
|
||||||
layout = RendererLayout(lcfg, nplots)
|
layout = RendererLayout(lcfg, [1] * nplots)
|
||||||
|
|
||||||
assert layout.ncols == 3
|
assert layout.wave_ncol == 3
|
||||||
assert layout.nrows == 3
|
assert layout.wave_nrow == 3
|
||||||
|
|
||||||
regions = layout.arrange(lambda row, col: region_type((row, col)))
|
region2d: List[List[RegionSpec]] = layout.arrange(lambda arg: arg)
|
||||||
assert len(regions) == nplots
|
assert len(region2d) == nplots
|
||||||
|
for i, regions in enumerate(region2d):
|
||||||
|
assert len(regions) == 1, (i, len(regions))
|
||||||
|
|
||||||
assert regions[0] == region_type((0, 0))
|
np.testing.assert_equal(region2d[0][0].pos, (0, 0))
|
||||||
assert regions[2] == region_type((2, 0))
|
np.testing.assert_equal(region2d[2][0].pos, (2, 0))
|
||||||
assert regions[3] == region_type((0, 1))
|
np.testing.assert_equal(region2d[3][0].pos, (0, 1))
|
||||||
assert regions[6] == region_type((0, 2))
|
np.testing.assert_equal(region2d[6][0].pos, (0, 2))
|
||||||
|
|
||||||
|
|
||||||
|
@given(
|
||||||
|
wave_nchans=hs.lists(hs.integers(1, 10), min_size=1, max_size=100),
|
||||||
|
orientation=hs.sampled_from(Orientation),
|
||||||
|
stereo_orientation=hs.sampled_from(StereoOrientation),
|
||||||
|
nrow_ncol=hs.integers(1, 100),
|
||||||
|
is_nrows=hs.booleans(),
|
||||||
|
)
|
||||||
|
def test_stereo_layout(
|
||||||
|
orientation: Orientation,
|
||||||
|
stereo_orientation: StereoOrientation,
|
||||||
|
wave_nchans: List[int],
|
||||||
|
nrow_ncol: int,
|
||||||
|
is_nrows: bool,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Not-entirely-rigorous test for layout computation.
|
||||||
|
Mind-numbingly boring to write (and read?).
|
||||||
|
|
||||||
|
Honestly I prefer a good naming scheme in RendererLayout.arrange()
|
||||||
|
over unit tests.
|
||||||
|
|
||||||
|
- This is a regression test...
|
||||||
|
- And an obstacle to refactoring or feature development.
|
||||||
|
"""
|
||||||
|
# region Setup
|
||||||
|
if is_nrows:
|
||||||
|
nrows = nrow_ncol
|
||||||
|
ncols = None
|
||||||
|
else:
|
||||||
|
nrows = None
|
||||||
|
ncols = nrow_ncol
|
||||||
|
|
||||||
|
lcfg = LayoutConfig(
|
||||||
|
orientation=orientation,
|
||||||
|
nrows=nrows,
|
||||||
|
ncols=ncols,
|
||||||
|
stereo_orientation=stereo_orientation,
|
||||||
|
)
|
||||||
|
nwaves = len(wave_nchans)
|
||||||
|
layout = RendererLayout(lcfg, wave_nchans)
|
||||||
|
# endregion
|
||||||
|
|
||||||
|
# Assert layout dimensions correct
|
||||||
|
assert layout.wave_ncol == ncols or ceildiv(nwaves, nrows)
|
||||||
|
assert layout.wave_nrow == nrows or ceildiv(nwaves, ncols)
|
||||||
|
|
||||||
|
region2d: List[List[RegionSpec]] = layout.arrange(lambda r_spec: r_spec)
|
||||||
|
|
||||||
|
# Loop through layout regions
|
||||||
|
assert len(region2d) == len(wave_nchans)
|
||||||
|
for wave_i, wave_chans in enumerate(region2d):
|
||||||
|
stereo_nchan = wave_nchans[wave_i]
|
||||||
|
assert len(wave_chans) == stereo_nchan
|
||||||
|
|
||||||
|
# Compute channel dims within wave.
|
||||||
|
if stereo_orientation == StereoOrientation.overlay:
|
||||||
|
chans_per_wave = [1, 1]
|
||||||
|
elif stereo_orientation == StereoOrientation.v: # pos[0]++
|
||||||
|
chans_per_wave = [stereo_nchan, 1]
|
||||||
|
else:
|
||||||
|
assert stereo_orientation == StereoOrientation.h # pos[1]++
|
||||||
|
chans_per_wave = [1, stereo_nchan]
|
||||||
|
|
||||||
|
# Sanity-check position of channel 0 relative to origin (wave grid).
|
||||||
|
assert (np.add.reduce(wave_chans[0].pos) != 0) == (wave_i != 0)
|
||||||
|
npt.assert_equal(wave_chans[0].pos % chans_per_wave, 0)
|
||||||
|
|
||||||
|
for chan_j, chan in enumerate(wave_chans):
|
||||||
|
# Assert 0 <= position < size.
|
||||||
|
assert chan.pos.shape == chan.size.shape == (2,)
|
||||||
|
assert (0 <= chan.pos).all()
|
||||||
|
assert (chan.pos < chan.size).all()
|
||||||
|
|
||||||
|
# Sanity-check position of chan relative to origin (wave grid).
|
||||||
|
npt.assert_equal(
|
||||||
|
chan.pos // chans_per_wave, wave_chans[0].pos // chans_per_wave
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check position of region (relative to channel 0)
|
||||||
|
chan_wave_pos = chan.pos - wave_chans[0].pos
|
||||||
|
|
||||||
|
if stereo_orientation == StereoOrientation.overlay:
|
||||||
|
npt.assert_equal(chan_wave_pos, [0, 0])
|
||||||
|
elif stereo_orientation == StereoOrientation.v: # pos[0]++
|
||||||
|
npt.assert_equal(chan_wave_pos, [chan_j, 0])
|
||||||
|
else:
|
||||||
|
assert stereo_orientation == StereoOrientation.h # pos[1]++
|
||||||
|
npt.assert_equal(chan_wave_pos, [0, chan_j])
|
||||||
|
|
||||||
|
# Check screen edges
|
||||||
|
screen_edges = chan.screen_edges
|
||||||
|
assert bool(screen_edges & Edges.Top) == (chan.row == 0)
|
||||||
|
assert bool(screen_edges & Edges.Left) == (chan.col == 0)
|
||||||
|
assert bool(screen_edges & Edges.Bottom) == (chan.row == chan.nrow - 1)
|
||||||
|
assert bool(screen_edges & Edges.Right) == (chan.col == chan.ncol - 1)
|
||||||
|
|
||||||
|
# Check stereo edges
|
||||||
|
wave_edges = chan.wave_edges
|
||||||
|
if stereo_orientation == StereoOrientation.overlay:
|
||||||
|
assert wave_edges == ~Edges.NONE
|
||||||
|
elif stereo_orientation == StereoOrientation.v: # pos[0]++
|
||||||
|
lr = Edges.Left | Edges.Right
|
||||||
|
assert wave_edges & lr == lr
|
||||||
|
assert bool(wave_edges & Edges.Top) == (chan.row % stereo_nchan == 0)
|
||||||
|
assert bool(wave_edges & Edges.Bottom) == (
|
||||||
|
(chan.row + 1) % stereo_nchan == 0
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert stereo_orientation == StereoOrientation.h # pos[1]++
|
||||||
|
tb = Edges.Top | Edges.Bottom
|
||||||
|
assert wave_edges & tb == tb
|
||||||
|
assert bool(wave_edges & Edges.Left) == (chan.col % stereo_nchan == 0)
|
||||||
|
assert bool(wave_edges & Edges.Right) == (
|
||||||
|
(chan.col + 1) % stereo_nchan == 0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_renderer_layout():
|
def test_renderer_layout():
|
||||||
|
@ -69,21 +222,9 @@ def test_renderer_layout():
|
||||||
nplots = 15
|
nplots = 15
|
||||||
|
|
||||||
r = MatplotlibRenderer(cfg, lcfg, nplots, None)
|
r = MatplotlibRenderer(cfg, lcfg, nplots, None)
|
||||||
|
r.render_frame([RENDER_Y_ZEROS] * nplots)
|
||||||
|
layout = r.layout
|
||||||
|
|
||||||
# 2 columns, 8 rows
|
# 2 columns, 8 rows
|
||||||
assert r.layout.ncols == 2
|
assert layout.wave_ncol == 2
|
||||||
assert r.layout.nrows == 8
|
assert layout.wave_nrow == 8
|
||||||
|
|
||||||
|
|
||||||
# Test EdgeFinder
|
|
||||||
|
|
||||||
|
|
||||||
def test_edge_finder():
|
|
||||||
regions2d = np.arange(24).reshape(3, 8)
|
|
||||||
edges = EdgeFinder(regions2d)
|
|
||||||
|
|
||||||
# Check borders
|
|
||||||
np.testing.assert_equal(edges.lefts, regions2d[:, 0])
|
|
||||||
np.testing.assert_equal(edges.rights, regions2d[:, -1])
|
|
||||||
np.testing.assert_equal(edges.tops, regions2d[0, :])
|
|
||||||
np.testing.assert_equal(edges.bottoms, regions2d[-1, :])
|
|
||||||
|
|
|
@ -23,11 +23,8 @@ from corrscope.outputs import (
|
||||||
Stop,
|
Stop,
|
||||||
)
|
)
|
||||||
from corrscope.renderer import RendererConfig, MatplotlibRenderer
|
from corrscope.renderer import RendererConfig, MatplotlibRenderer
|
||||||
from corrscope.settings.paths import MissingFFmpegError
|
from tests.test_renderer import RENDER_Y_ZEROS, WIDTH, HEIGHT
|
||||||
from tests.test_renderer import ALL_ZEROS
|
|
||||||
|
|
||||||
WIDTH = 192
|
|
||||||
HEIGHT = 108
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import pytest_mock
|
import pytest_mock
|
||||||
|
@ -90,7 +87,7 @@ def test_render_output():
|
||||||
renderer = MatplotlibRenderer(CFG.render, CFG.layout, nplots=1, channel_cfgs=None)
|
renderer = MatplotlibRenderer(CFG.render, CFG.layout, nplots=1, channel_cfgs=None)
|
||||||
out: FFmpegOutput = NULL_FFMPEG_OUTPUT(CFG)
|
out: FFmpegOutput = NULL_FFMPEG_OUTPUT(CFG)
|
||||||
|
|
||||||
renderer.render_frame([ALL_ZEROS])
|
renderer.render_frame([RENDER_Y_ZEROS])
|
||||||
out.write_frame(renderer.get_frame())
|
out.write_frame(renderer.get_frame())
|
||||||
|
|
||||||
assert out.close() == 0
|
assert out.close() == 0
|
||||||
|
|
|
@ -1,38 +1,47 @@
|
||||||
from typing import Optional
|
from typing import Optional, TYPE_CHECKING
|
||||||
|
|
||||||
|
import matplotlib.colors
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
from matplotlib.colors import to_rgb
|
|
||||||
|
|
||||||
from corrscope.channel import ChannelConfig
|
from corrscope.channel import ChannelConfig
|
||||||
|
from corrscope.corrscope import CorrScope, default_config, Arguments
|
||||||
from corrscope.layout import LayoutConfig
|
from corrscope.layout import LayoutConfig
|
||||||
from corrscope.outputs import RGB_DEPTH
|
from corrscope.outputs import RGB_DEPTH, FFplayOutputConfig
|
||||||
from corrscope.renderer import RendererConfig, MatplotlibRenderer
|
from corrscope.renderer import RendererConfig, MatplotlibRenderer
|
||||||
|
from corrscope.wave import Flatten
|
||||||
|
|
||||||
WIDTH = 640
|
if TYPE_CHECKING:
|
||||||
HEIGHT = 360
|
import pytest_mock
|
||||||
|
|
||||||
ALL_ZEROS = np.array([0, 0])
|
WIDTH = 64
|
||||||
|
HEIGHT = 64
|
||||||
|
|
||||||
|
RENDER_Y_ZEROS = np.zeros((2, 1))
|
||||||
|
RENDER_Y_STEREO = np.zeros((2, 2))
|
||||||
|
OPACITY = 2 / 3
|
||||||
|
|
||||||
all_colors = pytest.mark.parametrize(
|
all_colors = pytest.mark.parametrize(
|
||||||
"bg_str,fg_str,grid_str",
|
"bg_str,fg_str,grid_str,data",
|
||||||
[
|
[
|
||||||
("#000000", "#ffffff", None),
|
("#000000", "#ffffff", None, RENDER_Y_ZEROS),
|
||||||
("#ffffff", "#000000", None),
|
("#ffffff", "#000000", None, RENDER_Y_ZEROS),
|
||||||
("#0000aa", "#aaaa00", None),
|
("#0000aa", "#aaaa00", None, RENDER_Y_ZEROS),
|
||||||
("#aaaa00", "#0000aa", None),
|
("#aaaa00", "#0000aa", None, RENDER_Y_ZEROS),
|
||||||
# Enabling gridlines enables Axes rectangles.
|
# Enabling ~~beautiful magenta~~ gridlines enables Axes rectangles.
|
||||||
# Make sure they don't draw *over* the global figure background.
|
# Make sure bg is disabled, so they don't overwrite global figure background.
|
||||||
("#0000aa", "#aaaa00", "#ff00ff"), # beautiful magenta gridlines
|
("#0000aa", "#aaaa00", "#ff00ff", RENDER_Y_ZEROS),
|
||||||
("#aaaa00", "#0000aa", "#ff00ff"),
|
("#aaaa00", "#0000aa", "#ff00ff", RENDER_Y_ZEROS),
|
||||||
|
("#0000aa", "#aaaa00", "#ff00ff", RENDER_Y_STEREO),
|
||||||
|
("#aaaa00", "#0000aa", "#ff00ff", RENDER_Y_STEREO),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
nplots = 2
|
NPLOTS = 2
|
||||||
|
|
||||||
|
|
||||||
@all_colors
|
@all_colors
|
||||||
def test_default_colors(bg_str, fg_str, grid_str):
|
def test_default_colors(bg_str, fg_str, grid_str, data):
|
||||||
""" Test the default background/foreground colors. """
|
""" Test the default background/foreground colors. """
|
||||||
cfg = RendererConfig(
|
cfg = RendererConfig(
|
||||||
WIDTH,
|
WIDTH,
|
||||||
|
@ -40,23 +49,24 @@ def test_default_colors(bg_str, fg_str, grid_str):
|
||||||
bg_color=bg_str,
|
bg_color=bg_str,
|
||||||
init_line_color=fg_str,
|
init_line_color=fg_str,
|
||||||
grid_color=grid_str,
|
grid_color=grid_str,
|
||||||
|
stereo_grid_opacity=OPACITY,
|
||||||
line_width=2.0,
|
line_width=2.0,
|
||||||
antialiasing=False,
|
antialiasing=False,
|
||||||
)
|
)
|
||||||
lcfg = LayoutConfig()
|
lcfg = LayoutConfig()
|
||||||
|
|
||||||
r = MatplotlibRenderer(cfg, lcfg, nplots, None)
|
r = MatplotlibRenderer(cfg, lcfg, NPLOTS, None)
|
||||||
verify(r, bg_str, fg_str, grid_str)
|
verify(r, bg_str, fg_str, grid_str, data)
|
||||||
|
|
||||||
# Ensure default ChannelConfig(line_color=None) does not override line color
|
# Ensure default ChannelConfig(line_color=None) does not override line color
|
||||||
chan = ChannelConfig(wav_path="")
|
chan = ChannelConfig(wav_path="")
|
||||||
channels = [chan] * nplots
|
channels = [chan] * NPLOTS
|
||||||
r = MatplotlibRenderer(cfg, lcfg, nplots, channels)
|
r = MatplotlibRenderer(cfg, lcfg, NPLOTS, channels)
|
||||||
verify(r, bg_str, fg_str, grid_str)
|
verify(r, bg_str, fg_str, grid_str, data)
|
||||||
|
|
||||||
|
|
||||||
@all_colors
|
@all_colors
|
||||||
def test_line_colors(bg_str, fg_str, grid_str):
|
def test_line_colors(bg_str, fg_str, grid_str, data):
|
||||||
""" Test channel-specific line color overrides """
|
""" Test channel-specific line color overrides """
|
||||||
cfg = RendererConfig(
|
cfg = RendererConfig(
|
||||||
WIDTH,
|
WIDTH,
|
||||||
|
@ -64,30 +74,44 @@ def test_line_colors(bg_str, fg_str, grid_str):
|
||||||
bg_color=bg_str,
|
bg_color=bg_str,
|
||||||
init_line_color="#888888",
|
init_line_color="#888888",
|
||||||
grid_color=grid_str,
|
grid_color=grid_str,
|
||||||
|
stereo_grid_opacity=OPACITY,
|
||||||
line_width=2.0,
|
line_width=2.0,
|
||||||
antialiasing=False,
|
antialiasing=False,
|
||||||
)
|
)
|
||||||
lcfg = LayoutConfig()
|
lcfg = LayoutConfig()
|
||||||
|
|
||||||
chan = ChannelConfig(wav_path="", line_color=fg_str)
|
chan = ChannelConfig(wav_path="", line_color=fg_str)
|
||||||
channels = [chan] * nplots
|
channels = [chan] * NPLOTS
|
||||||
r = MatplotlibRenderer(cfg, lcfg, nplots, channels)
|
r = MatplotlibRenderer(cfg, lcfg, NPLOTS, channels)
|
||||||
verify(r, bg_str, fg_str, grid_str)
|
verify(r, bg_str, fg_str, grid_str, data)
|
||||||
|
|
||||||
|
|
||||||
def verify(r: MatplotlibRenderer, bg_str, fg_str, grid_str: Optional[str]):
|
TOLERANCE = 3
|
||||||
r.render_frame([ALL_ZEROS] * nplots)
|
|
||||||
|
|
||||||
|
def verify(
|
||||||
|
r: MatplotlibRenderer, bg_str, fg_str, grid_str: Optional[str], data: np.ndarray
|
||||||
|
):
|
||||||
|
r.render_frame([data] * NPLOTS)
|
||||||
frame_colors: np.ndarray = np.frombuffer(r.get_frame(), dtype=np.uint8).reshape(
|
frame_colors: np.ndarray = np.frombuffer(r.get_frame(), dtype=np.uint8).reshape(
|
||||||
(-1, RGB_DEPTH)
|
(-1, RGB_DEPTH)
|
||||||
)
|
)
|
||||||
|
|
||||||
bg_u8 = [round(c * 255) for c in to_rgb(bg_str)]
|
bg_u8 = to_rgb(bg_str)
|
||||||
fg_u8 = [round(c * 255) for c in to_rgb(fg_str)]
|
fg_u8 = to_rgb(fg_str)
|
||||||
all_colors = [bg_u8, fg_u8]
|
all_colors = [bg_u8, fg_u8]
|
||||||
|
|
||||||
if grid_str:
|
if grid_str:
|
||||||
grid_u8 = [round(c * 255) for c in to_rgb(grid_str)]
|
grid_u8 = to_rgb(grid_str)
|
||||||
all_colors.append(grid_u8)
|
all_colors.append(grid_u8)
|
||||||
|
else:
|
||||||
|
grid_u8 = bg_u8
|
||||||
|
|
||||||
|
assert (data.shape[1] > 1) == (data is RENDER_Y_STEREO)
|
||||||
|
is_stereo = data.shape[1] > 1
|
||||||
|
if is_stereo:
|
||||||
|
stereo_grid_u8 = (grid_u8 * OPACITY + bg_u8 * (1 - OPACITY)).astype(int)
|
||||||
|
all_colors.append(stereo_grid_u8)
|
||||||
|
|
||||||
# Ensure background is correct
|
# Ensure background is correct
|
||||||
bg_frame = frame_colors[0]
|
bg_frame = frame_colors[0]
|
||||||
|
@ -104,5 +128,36 @@ def verify(r: MatplotlibRenderer, bg_str, fg_str, grid_str: Optional[str]):
|
||||||
if grid_str:
|
if grid_str:
|
||||||
assert np.prod(frame_colors == grid_u8, axis=-1).any(), "Missing grid_str"
|
assert np.prod(frame_colors == grid_u8, axis=-1).any(), "Missing grid_str"
|
||||||
|
|
||||||
|
# Ensure stereo grid color is present
|
||||||
|
if is_stereo:
|
||||||
|
assert (
|
||||||
|
np.min(np.sum(np.abs(frame_colors - stereo_grid_u8), axis=-1)) < TOLERANCE
|
||||||
|
), "Missing stereo gridlines"
|
||||||
|
|
||||||
assert (np.amax(frame_colors, axis=0) == np.amax(all_colors, axis=0)).all()
|
assert (np.amax(frame_colors, axis=0) == np.amax(all_colors, axis=0)).all()
|
||||||
assert (np.amin(frame_colors, axis=0) == np.amin(all_colors, axis=0)).all()
|
assert (np.amin(frame_colors, axis=0) == np.amin(all_colors, axis=0)).all()
|
||||||
|
|
||||||
|
|
||||||
|
def to_rgb(c) -> np.ndarray:
|
||||||
|
to_rgb = matplotlib.colors.to_rgb
|
||||||
|
return np.array([round(c * 255) for c in to_rgb(c)], dtype=int)
|
||||||
|
|
||||||
|
|
||||||
|
# Stereo *renderer* integration tests.
|
||||||
|
def test_stereo_render_integration(mocker: "pytest_mock.MockFixture"):
|
||||||
|
"""Ensure corrscope plays/renders in stereo, without crashing."""
|
||||||
|
|
||||||
|
# Stub out FFplay output.
|
||||||
|
mocker.patch.object(FFplayOutputConfig, "cls")
|
||||||
|
|
||||||
|
# Render in stereo.
|
||||||
|
cfg = default_config(
|
||||||
|
channels=[ChannelConfig("tests/stereo in-phase.wav")],
|
||||||
|
render_stereo=Flatten.Stereo,
|
||||||
|
end_time=0.5, # Reduce test duration
|
||||||
|
render=RendererConfig(WIDTH, HEIGHT),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Make sure it doesn't crash.
|
||||||
|
corr = CorrScope(cfg, Arguments(".", [FFplayOutputConfig()]))
|
||||||
|
corr.play()
|
||||||
|
|
|
@ -79,6 +79,7 @@ AllFlattens = Flatten.__members__.values()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("flatten", AllFlattens)
|
@pytest.mark.parametrize("flatten", AllFlattens)
|
||||||
|
@pytest.mark.parametrize("return_channels", [False, True])
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"path,nchan,peaks",
|
"path,nchan,peaks",
|
||||||
[
|
[
|
||||||
|
@ -88,19 +89,30 @@ AllFlattens = Flatten.__members__.values()
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_stereo_flatten_modes(
|
def test_stereo_flatten_modes(
|
||||||
flatten: Flatten, path: str, nchan: int, peaks: Sequence[float]
|
flatten: Flatten,
|
||||||
|
return_channels: bool,
|
||||||
|
path: str,
|
||||||
|
nchan: int,
|
||||||
|
peaks: Sequence[float],
|
||||||
):
|
):
|
||||||
"""Ensures all Flatten modes are handled properly
|
"""Ensures all Flatten modes are handled properly
|
||||||
for stereo and mono signals."""
|
for stereo and mono signals."""
|
||||||
|
|
||||||
|
# return_channels=False <-> triggering.
|
||||||
|
# flatten=stereo -> rendering.
|
||||||
|
# These conditions do not currently coexist.
|
||||||
|
# if not return_channels and flatten == Flatten.Stereo:
|
||||||
|
# return
|
||||||
|
|
||||||
assert nchan == len(peaks)
|
assert nchan == len(peaks)
|
||||||
wave = Wave(path)
|
wave = Wave(path)
|
||||||
|
|
||||||
if flatten not in Flatten.modes:
|
if flatten not in Flatten.modes:
|
||||||
with pytest.raises(CorrError):
|
with pytest.raises(CorrError):
|
||||||
wave.with_flatten(flatten)
|
wave.with_flatten(flatten, return_channels)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
wave = wave.with_flatten(flatten)
|
wave = wave.with_flatten(flatten, return_channels)
|
||||||
|
|
||||||
nsamp = wave.nsamp
|
nsamp = wave.nsamp
|
||||||
data = wave[:]
|
data = wave[:]
|
||||||
|
@ -111,7 +123,10 @@ def test_stereo_flatten_modes(
|
||||||
for chan_data, peak in zip(data.T, peaks):
|
for chan_data, peak in zip(data.T, peaks):
|
||||||
assert_full_scale(chan_data, peak)
|
assert_full_scale(chan_data, peak)
|
||||||
else:
|
else:
|
||||||
assert data.shape == (nsamp,)
|
if return_channels:
|
||||||
|
assert data.shape == (nsamp, 1)
|
||||||
|
else:
|
||||||
|
assert data.shape == (nsamp,)
|
||||||
|
|
||||||
# If DiffAvg and in-phase, L-R=0.
|
# If DiffAvg and in-phase, L-R=0.
|
||||||
if flatten == Flatten.DiffAvg:
|
if flatten == Flatten.DiffAvg:
|
||||||
|
|
Ładowanie…
Reference in New Issue