kopia lustrzana https://github.com/corrscope/corrscope
Refactor renderer API
rodzic
2fc908dd5a
commit
3849f7fe85
|
@ -214,8 +214,9 @@ class CorrScope:
|
|||
yield
|
||||
|
||||
def _load_renderer(self) -> Renderer:
|
||||
dummy_datas = [channel.get_render_around(0) for channel in self.channels]
|
||||
renderer = MatplotlibRenderer(
|
||||
self.cfg.render, self.cfg.layout, self.nchan, self.cfg.channels
|
||||
self.cfg.render, self.cfg.layout, dummy_datas, self.cfg.channels
|
||||
)
|
||||
return renderer
|
||||
|
||||
|
@ -293,7 +294,7 @@ class CorrScope:
|
|||
|
||||
if not_benchmarking or benchmark_mode >= BenchmarkMode.RENDER:
|
||||
# Render frame
|
||||
renderer.render_frame(render_datas)
|
||||
renderer.update_main_lines(render_datas)
|
||||
frame_data = renderer.get_frame()
|
||||
|
||||
if not_benchmarking or benchmark_mode == BenchmarkMode.OUTPUT:
|
||||
|
|
|
@ -153,12 +153,16 @@ class RendererLayout:
|
|||
self.wave_nrow = nrows
|
||||
self.wave_ncol = ncols
|
||||
|
||||
def arrange(self, region_factory: RegionFactory[Region]) -> List[List[Region]]:
|
||||
def arrange(
|
||||
self, region_factory: RegionFactory[Region], **kwargs
|
||||
) -> List[List[Region]]:
|
||||
"""
|
||||
(row, column) are fed into region_factory in a row-major order [row][col].
|
||||
Stereo channel pairs are extracted.
|
||||
The results are possibly reshaped into column-major order [col][row].
|
||||
|
||||
**kwargs -> region_factory(**kwargs).
|
||||
|
||||
:return arr[wave][channel] = Region
|
||||
"""
|
||||
|
||||
|
@ -229,7 +233,7 @@ class RendererLayout:
|
|||
screen_edges,
|
||||
wave_edges,
|
||||
)
|
||||
region = region_factory(chan_spec)
|
||||
region = region_factory(chan_spec, **kwargs)
|
||||
region_chan.append(region)
|
||||
|
||||
# Move to next channel position
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, List, TYPE_CHECKING, Any
|
||||
from typing import Optional, List, TYPE_CHECKING, Any, Callable
|
||||
|
||||
import attr
|
||||
import matplotlib
|
||||
import matplotlib # do NOT import anything else until we call matplotlib.use().
|
||||
import matplotlib.colors
|
||||
import numpy as np
|
||||
|
||||
|
@ -43,6 +43,7 @@ from matplotlib.backends.backend_agg import FigureCanvasAgg
|
|||
from matplotlib.figure import Figure
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.artist import Artist
|
||||
from matplotlib.axes import Axes
|
||||
from matplotlib.lines import Line2D
|
||||
from corrscope.channel import ChannelConfig
|
||||
|
@ -99,18 +100,24 @@ class LineParam:
|
|||
color: str
|
||||
|
||||
|
||||
UpdateLines = Callable[[List[np.ndarray]], None]
|
||||
|
||||
# TODO rename to Plotter
|
||||
class Renderer(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
cfg: RendererConfig,
|
||||
lcfg: "LayoutConfig",
|
||||
nplots: int,
|
||||
dummy_datas: List[np.ndarray],
|
||||
channel_cfgs: Optional[List["ChannelConfig"]],
|
||||
):
|
||||
self.cfg = cfg
|
||||
self.lcfg = lcfg
|
||||
self.nplots = nplots
|
||||
self.nplots = len(dummy_datas)
|
||||
|
||||
assert len(dummy_datas[0].shape) == 2, dummy_datas[0].shape
|
||||
self.wave_nsamps = [data.shape[0] for data in dummy_datas]
|
||||
self.wave_nchans = [data.shape[1] for data in dummy_datas]
|
||||
|
||||
# Load line colors.
|
||||
if channel_cfgs is not None:
|
||||
|
@ -127,8 +134,16 @@ class Renderer(ABC):
|
|||
for color in line_colors
|
||||
]
|
||||
|
||||
_update_main_lines: Optional[UpdateLines] = None
|
||||
|
||||
def update_main_lines(self, datas: List[np.ndarray]) -> None:
|
||||
if self._update_main_lines is None:
|
||||
self._update_main_lines = self.add_lines(datas)
|
||||
|
||||
self._update_main_lines(datas)
|
||||
|
||||
@abstractmethod
|
||||
def render_frame(self, datas: List[np.ndarray]) -> None:
|
||||
def add_lines(self, dummy_datas: List[np.ndarray]) -> UpdateLines:
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
|
@ -175,217 +190,221 @@ class MatplotlibRenderer(Renderer):
|
|||
matplotlib.rcParams, "lines.antialiased", self.cfg.antialiasing
|
||||
)
|
||||
|
||||
self._fig: "Figure"
|
||||
self._setup_axes(self.wave_nchans)
|
||||
|
||||
# _axes2d[wave][chan] = Axes
|
||||
self._axes2d: List[List["Axes"]] # set by set_layout()
|
||||
self._artists: List["Artist"] = []
|
||||
|
||||
# _lines2d[wave][chan] = Line2D
|
||||
self._lines2d: List[List[Line2D]] = []
|
||||
self._lines_flat: List["Line2D"] = []
|
||||
_fig: "Figure"
|
||||
|
||||
transparent = "#00000000"
|
||||
# _axes2d[wave][chan] = Axes
|
||||
# Primary, used to draw oscilloscope lines and gridlines.
|
||||
_axes2d: List[List["Axes"]] # set by set_layout()
|
||||
|
||||
layout: RendererLayout
|
||||
# _axes_mono[wave] = Axes
|
||||
# Secondary, used for titles and debug plots.
|
||||
_axes_mono: List["Axes"]
|
||||
|
||||
def _set_layout(self, wave_nchans: List[int]) -> None:
|
||||
def _setup_axes(self, wave_nchans: List[int]) -> None:
|
||||
"""
|
||||
Creates a flat array of Matplotlib Axes, with the new layout.
|
||||
Opens a window showing the Figure (and Axes).
|
||||
|
||||
Inputs: self.cfg, self.fig
|
||||
Outputs: self.nrows, self.ncols, self.axes
|
||||
Sets up each Axes with correct region limits.
|
||||
"""
|
||||
|
||||
self.layout = RendererLayout(self.lcfg, wave_nchans)
|
||||
self.layout_mono = RendererLayout(self.lcfg, [1] * self.nplots)
|
||||
|
||||
# Create Axes
|
||||
# https://matplotlib.org/api/_as_gen/matplotlib.pyplot.subplots.html
|
||||
if hasattr(self, "_fig"):
|
||||
raise Exception("I don't currently expect to call _set_layout() twice")
|
||||
raise Exception("I don't currently expect to call _setup_axes() twice")
|
||||
# plt.close(self.fig)
|
||||
|
||||
grid_color = self.cfg.grid_color
|
||||
cfg = self.cfg
|
||||
|
||||
self._fig = Figure()
|
||||
FigureCanvasAgg(self._fig)
|
||||
|
||||
# RegionFactory
|
||||
def axes_factory(r: RegionSpec) -> "Axes":
|
||||
width = 1 / r.ncol
|
||||
left = r.col / r.ncol
|
||||
assert 0 <= left < 1
|
||||
|
||||
height = 1 / r.nrow
|
||||
bottom = (r.nrow - r.row - 1) / r.nrow
|
||||
assert 0 <= bottom < 1
|
||||
|
||||
# Disabling xticks/yticks is unnecessary, since we hide Axises.
|
||||
ax = self._fig.add_axes([left, bottom, width, height], xticks=[], yticks=[])
|
||||
|
||||
if grid_color:
|
||||
# Initialize borders
|
||||
# Hide Axises
|
||||
# (drawing them is very slow, and we disable ticks+labels anyway)
|
||||
ax.get_xaxis().set_visible(False)
|
||||
ax.get_yaxis().set_visible(False)
|
||||
|
||||
# Background color
|
||||
# ax.patch.set_fill(False) sets _fill=False,
|
||||
# then calls _set_facecolor(...) "alpha = self._alpha if self._fill else 0".
|
||||
# It is no faster than below.
|
||||
ax.set_facecolor(self.transparent)
|
||||
|
||||
# Set border colors
|
||||
for spine in ax.spines.values():
|
||||
spine.set_color(grid_color)
|
||||
|
||||
def hide(key: str):
|
||||
ax.spines[key].set_visible(False)
|
||||
|
||||
# Hide all axes except bottom-right.
|
||||
hide("top")
|
||||
hide("left")
|
||||
|
||||
# If bottom of screen, hide bottom. If right of screen, hide right.
|
||||
if r.screen_edges & Edges.Bottom:
|
||||
hide("bottom")
|
||||
if r.screen_edges & Edges.Right:
|
||||
hide("right")
|
||||
|
||||
# Dim stereo gridlines
|
||||
if self.cfg.stereo_grid_opacity > 0:
|
||||
dim_color = matplotlib.colors.to_rgba_array(grid_color)[0]
|
||||
dim_color[-1] = self.cfg.stereo_grid_opacity
|
||||
|
||||
def dim(key: str):
|
||||
ax.spines[key].set_color(dim_color)
|
||||
|
||||
else:
|
||||
dim = hide
|
||||
|
||||
# If not bottom of wave, dim bottom. If not right of wave, dim right.
|
||||
if not r.wave_edges & Edges.Bottom:
|
||||
dim("bottom")
|
||||
if not r.wave_edges & Edges.Right:
|
||||
dim("right")
|
||||
|
||||
else:
|
||||
ax.set_axis_off()
|
||||
|
||||
return ax
|
||||
|
||||
# Generate arrangement (using self.lcfg, wave_nchans)
|
||||
# _axes2d[wave][chan] = Axes
|
||||
self._axes2d = self.layout.arrange(axes_factory)
|
||||
|
||||
# Setup figure geometry
|
||||
self._fig.set_dpi(DPI)
|
||||
self._fig.set_size_inches(self.cfg.width / DPI, self.cfg.height / DPI)
|
||||
FigureCanvasAgg(self._fig)
|
||||
|
||||
def render_frame(self, datas: List[np.ndarray]) -> None:
|
||||
# Setup background
|
||||
self._fig.set_facecolor(cfg.bg_color)
|
||||
|
||||
# Create Axes (using self.lcfg, wave_nchans)
|
||||
# _axes2d[wave][chan] = Axes
|
||||
self._axes2d = self.layout.arrange(self._axes_factory)
|
||||
|
||||
"""
|
||||
Adding an axes using the same arguments as a previous axes
|
||||
currently reuses the earlier instance.
|
||||
In a future version, a new instance will always be created and returned.
|
||||
Meanwhile, this warning can be suppressed, and the future behavior ensured,
|
||||
by passing a unique label to each axes instance.
|
||||
|
||||
ax=fig.add_axes(label=) is unused, even if you call ax.legend().
|
||||
"""
|
||||
# _axes_mono[wave] = Axes
|
||||
self._axes_mono = []
|
||||
# Returns 2D list of [self.nplots][1]Axes.
|
||||
axes_mono_2d = self.layout_mono.arrange(self._axes_factory, label="mono")
|
||||
for axes_list in axes_mono_2d:
|
||||
assert len(axes_list) == 1
|
||||
self._axes_mono.extend(axes_list)
|
||||
|
||||
# Setup axes
|
||||
for idx, N in enumerate(self.wave_nsamps):
|
||||
wave_axes = self._axes2d[idx]
|
||||
max_x = N - 1
|
||||
|
||||
def scale_axes(ax: "Axes"):
|
||||
ax.set_xlim(0, max_x)
|
||||
ax.set_ylim(-1, 1)
|
||||
|
||||
scale_axes(self._axes_mono[idx])
|
||||
for ax in unique_by_id(wave_axes):
|
||||
scale_axes(ax)
|
||||
|
||||
# Setup midlines (depends on max_x and wave_data)
|
||||
midline_color = cfg.midline_color
|
||||
midline_width = pixels(1)
|
||||
|
||||
# Not quite sure if midlines or gridlines draw on top
|
||||
kw = dict(color=midline_color, linewidth=midline_width)
|
||||
if cfg.v_midline:
|
||||
# See Wave.get_around() docstring.
|
||||
# wave_data[N//2] == self[sample], usually > 0.
|
||||
ax.axvline(x=N // 2 - 0.5, **kw)
|
||||
if cfg.h_midline:
|
||||
ax.axhline(y=0, **kw)
|
||||
|
||||
self._save_background()
|
||||
|
||||
transparent = "#00000000"
|
||||
|
||||
# satisfies RegionFactory
|
||||
def _axes_factory(self, r: RegionSpec, label: str = "") -> "Axes":
|
||||
grid_color = self.cfg.grid_color
|
||||
|
||||
width = 1 / r.ncol
|
||||
left = r.col / r.ncol
|
||||
assert 0 <= left < 1
|
||||
|
||||
height = 1 / r.nrow
|
||||
bottom = (r.nrow - r.row - 1) / r.nrow
|
||||
assert 0 <= bottom < 1
|
||||
|
||||
# Disabling xticks/yticks is unnecessary, since we hide Axises.
|
||||
ax = self._fig.add_axes(
|
||||
[left, bottom, width, height], xticks=[], yticks=[], label=label
|
||||
)
|
||||
|
||||
if grid_color:
|
||||
# Initialize borders
|
||||
# Hide Axises
|
||||
# (drawing them is very slow, and we disable ticks+labels anyway)
|
||||
ax.get_xaxis().set_visible(False)
|
||||
ax.get_yaxis().set_visible(False)
|
||||
|
||||
# Background color
|
||||
# ax.patch.set_fill(False) sets _fill=False,
|
||||
# then calls _set_facecolor(...) "alpha = self._alpha if self._fill else 0".
|
||||
# It is no faster than below.
|
||||
ax.set_facecolor(self.transparent)
|
||||
|
||||
# Set border colors
|
||||
for spine in ax.spines.values():
|
||||
spine.set_color(grid_color)
|
||||
|
||||
def hide(key: str):
|
||||
ax.spines[key].set_visible(False)
|
||||
|
||||
# Hide all axes except bottom-right.
|
||||
hide("top")
|
||||
hide("left")
|
||||
|
||||
# If bottom of screen, hide bottom. If right of screen, hide right.
|
||||
if r.screen_edges & Edges.Bottom:
|
||||
hide("bottom")
|
||||
if r.screen_edges & Edges.Right:
|
||||
hide("right")
|
||||
|
||||
# Dim stereo gridlines
|
||||
if self.cfg.stereo_grid_opacity > 0:
|
||||
dim_color = matplotlib.colors.to_rgba_array(grid_color)[0]
|
||||
dim_color[-1] = self.cfg.stereo_grid_opacity
|
||||
|
||||
def dim(key: str):
|
||||
ax.spines[key].set_color(dim_color)
|
||||
|
||||
else:
|
||||
dim = hide
|
||||
|
||||
# If not bottom of wave, dim bottom. If not right of wave, dim right.
|
||||
if not r.wave_edges & Edges.Bottom:
|
||||
dim("bottom")
|
||||
if not r.wave_edges & Edges.Right:
|
||||
dim("right")
|
||||
|
||||
else:
|
||||
ax.set_axis_off()
|
||||
|
||||
return ax
|
||||
|
||||
# Public API
|
||||
def add_lines(self, dummy_datas: List[np.ndarray]) -> UpdateLines:
|
||||
cfg = self.cfg
|
||||
|
||||
# Plot lines over background
|
||||
line_width = pixels(cfg.line_width)
|
||||
|
||||
# Foreach wave, plot dummy data.
|
||||
lines2d = []
|
||||
for wave_idx, wave_data in enumerate(dummy_datas):
|
||||
wave_zeros = np.zeros_like(wave_data)
|
||||
|
||||
wave_axes = self._axes2d[wave_idx]
|
||||
wave_lines = []
|
||||
|
||||
# Foreach chan
|
||||
for chan_idx, chan_zeros in enumerate(wave_zeros.T):
|
||||
ax = wave_axes[chan_idx]
|
||||
line_color = self._line_params[wave_idx].color
|
||||
chan_line: Line2D = ax.plot(
|
||||
chan_zeros, color=line_color, linewidth=line_width
|
||||
)[0]
|
||||
wave_lines.append(chan_line)
|
||||
|
||||
lines2d.append(wave_lines)
|
||||
self._artists.extend(wave_lines)
|
||||
|
||||
return lambda datas: self._update_lines(lines2d, datas)
|
||||
|
||||
@staticmethod
|
||||
def _update_lines(lines2d: "List[List[Line2D]]", datas: List[np.ndarray]) -> None:
|
||||
"""
|
||||
Preconditions:
|
||||
- lines2d[wave][chan] = Line2D
|
||||
- datas[wave] = ndarray, [samp][chan] = FLOAT
|
||||
"""
|
||||
nplots = len(lines2d)
|
||||
ndata = len(datas)
|
||||
if self.nplots != ndata:
|
||||
if nplots != ndata:
|
||||
raise ValueError(
|
||||
f"incorrect data to plot: {self.nplots} plots but {ndata} datas"
|
||||
f"incorrect data to plot: {nplots} plots but {ndata} dummy_datas"
|
||||
)
|
||||
|
||||
# Initialize axes and draw waveform data
|
||||
if not self._lines2d:
|
||||
assert len(datas[0].shape) == 2, datas[0].shape
|
||||
|
||||
wave_nchans = [data.shape[1] for data in datas]
|
||||
self._set_layout(wave_nchans)
|
||||
|
||||
cfg = self.cfg
|
||||
|
||||
# Setup background/axes
|
||||
self._fig.set_facecolor(cfg.bg_color)
|
||||
for idx, wave_data in enumerate(datas):
|
||||
wave_axes = self._axes2d[idx]
|
||||
for ax in unique_by_id(wave_axes):
|
||||
N = len(wave_data)
|
||||
max_x = N - 1
|
||||
ax.set_xlim(0, max_x)
|
||||
ax.set_ylim(-1, 1)
|
||||
|
||||
# Setup midlines (depends on max_x and wave_data)
|
||||
midline_color = cfg.midline_color
|
||||
midline_width = pixels(1)
|
||||
|
||||
# zorder=-100 still draws on top of gridlines :(
|
||||
kw = dict(color=midline_color, linewidth=midline_width)
|
||||
if cfg.v_midline:
|
||||
# See Wave.get_around() docstring.
|
||||
# wave_data[N//2] == self[sample], usually > 0.
|
||||
ax.axvline(x=N // 2 - 0.5, **kw)
|
||||
if cfg.h_midline:
|
||||
ax.axhline(y=0, **kw)
|
||||
|
||||
self._save_background()
|
||||
|
||||
# Plot lines over background
|
||||
line_width = pixels(cfg.line_width)
|
||||
|
||||
# Foreach wave
|
||||
for wave_idx, wave_data in enumerate(datas):
|
||||
wave_axes = self._axes2d[wave_idx]
|
||||
wave_lines = []
|
||||
|
||||
# Foreach chan
|
||||
for chan_idx, chan_data in enumerate(wave_data.T):
|
||||
ax = wave_axes[chan_idx]
|
||||
line_color = self._line_params[wave_idx].color
|
||||
chan_line: Line2D = ax.plot(
|
||||
chan_data, color=line_color, linewidth=line_width
|
||||
)[0]
|
||||
wave_lines.append(chan_line)
|
||||
|
||||
self._lines2d.append(wave_lines)
|
||||
self._lines_flat.extend(wave_lines)
|
||||
|
||||
# Draw waveform data
|
||||
else:
|
||||
# Foreach wave
|
||||
for wave_idx, wave_data in enumerate(datas):
|
||||
wave_lines = self._lines2d[wave_idx]
|
||||
# Foreach wave
|
||||
for wave_idx, wave_data in enumerate(datas):
|
||||
wave_lines = lines2d[wave_idx]
|
||||
|
||||
# Foreach chan
|
||||
for chan_idx, chan_data in enumerate(wave_data.T):
|
||||
chan_line = wave_lines[chan_idx]
|
||||
chan_line.set_ydata(chan_data)
|
||||
|
||||
self._redraw_over_background()
|
||||
|
||||
bg_cache: Any # "matplotlib.backends._backend_agg.BufferRegion"
|
||||
|
||||
def _save_background(self) -> None:
|
||||
""" Draw static background. """
|
||||
# https://stackoverflow.com/a/8956211
|
||||
# https://matplotlib.org/api/animation_api.html#funcanimation
|
||||
fig = self._fig
|
||||
|
||||
fig.canvas.draw()
|
||||
self.bg_cache = fig.canvas.copy_from_bbox(fig.bbox)
|
||||
|
||||
def _redraw_over_background(self) -> None:
|
||||
""" Redraw animated elements of the image. """
|
||||
|
||||
canvas: FigureCanvasAgg = self._fig.canvas
|
||||
canvas.restore_region(self.bg_cache)
|
||||
|
||||
for line in self._lines_flat:
|
||||
line.axes.draw_artist(line)
|
||||
|
||||
# https://bastibe.de/2013-05-30-speeding-up-matplotlib.html
|
||||
# thinks fig.canvas.blit(ax.bbox) leaks memory
|
||||
# and fig.canvas.update() works.
|
||||
# Except I found no memory leak...
|
||||
# and update() doesn't exist in FigureCanvasBase when no GUI is present.
|
||||
|
||||
canvas.blit(self._fig.bbox)
|
||||
# Foreach chan
|
||||
for chan_idx, chan_data in enumerate(wave_data.T):
|
||||
chan_line = wave_lines[chan_idx]
|
||||
chan_line.set_ydata(chan_data)
|
||||
|
||||
# Output frames
|
||||
def get_frame(self) -> ByteBuffer:
|
||||
""" Returns ndarray of shape w,h,3. """
|
||||
self._redraw_over_background()
|
||||
|
||||
canvas = self._fig.canvas
|
||||
|
||||
# Agg is the default noninteractive backend except on OSX.
|
||||
|
@ -403,3 +422,26 @@ class MatplotlibRenderer(Renderer):
|
|||
assert len(buffer_rgb) == w * h * RGB_DEPTH
|
||||
|
||||
return buffer_rgb
|
||||
|
||||
# Pre-rendered background
|
||||
bg_cache: Any # "matplotlib.backends._backend_agg.BufferRegion"
|
||||
|
||||
def _save_background(self) -> None:
|
||||
""" Draw static background. """
|
||||
# https://stackoverflow.com/a/8956211
|
||||
# https://matplotlib.org/api/animation_api.html#funcanimation
|
||||
fig = self._fig
|
||||
|
||||
fig.canvas.draw()
|
||||
self.bg_cache = fig.canvas.copy_from_bbox(fig.bbox)
|
||||
|
||||
def _redraw_over_background(self) -> None:
|
||||
""" Redraw animated elements of the image. """
|
||||
|
||||
canvas: FigureCanvasAgg = self._fig.canvas
|
||||
canvas.restore_region(self.bg_cache)
|
||||
|
||||
for artist in self._artists:
|
||||
artist.axes.draw_artist(artist)
|
||||
|
||||
# canvas.blit(self._fig.bbox) is unnecessary when drawing off-screen.
|
||||
|
|
|
@ -127,9 +127,9 @@ def test_config_channel_width_stride(
|
|||
assert _return_nsamp == channel._render_samp
|
||||
assert _subsampling == channel._render_stride
|
||||
|
||||
# Inspect arguments to renderer.render_frame()
|
||||
# Inspect arguments to renderer.update_main_lines()
|
||||
# datas: List[np.ndarray]
|
||||
(datas,), kwargs = renderer.render_frame.call_args
|
||||
(datas,), kwargs = renderer.update_main_lines.call_args
|
||||
render_data = datas[0]
|
||||
assert len(render_data) == channel._render_samp
|
||||
|
||||
|
|
|
@ -221,8 +221,9 @@ def test_renderer_layout():
|
|||
lcfg = LayoutConfig(ncols=2)
|
||||
nplots = 15
|
||||
|
||||
r = MatplotlibRenderer(cfg, lcfg, nplots, None)
|
||||
r.render_frame([RENDER_Y_ZEROS] * nplots)
|
||||
datas = [RENDER_Y_ZEROS] * nplots
|
||||
r = MatplotlibRenderer(cfg, lcfg, datas, None)
|
||||
r.update_main_lines(datas)
|
||||
layout = r.layout
|
||||
|
||||
# 2 columns, 8 rows
|
||||
|
|
|
@ -84,10 +84,12 @@ def sine440_config():
|
|||
# Calls MatplotlibRenderer, FFmpegOutput, FFmpeg.
|
||||
def test_render_output():
|
||||
""" Ensure rendering to output does not raise exceptions. """
|
||||
renderer = MatplotlibRenderer(CFG.render, CFG.layout, nplots=1, channel_cfgs=None)
|
||||
datas = [RENDER_Y_ZEROS]
|
||||
|
||||
renderer = MatplotlibRenderer(CFG.render, CFG.layout, datas, channel_cfgs=None)
|
||||
out: FFmpegOutput = NULL_FFMPEG_OUTPUT(CFG)
|
||||
|
||||
renderer.render_frame([RENDER_Y_ZEROS])
|
||||
renderer.update_main_lines(datas)
|
||||
out.write_frame(renderer.get_frame())
|
||||
|
||||
assert out.close() == 0
|
||||
|
@ -166,8 +168,8 @@ def test_corr_terminate_ffplay(Popen, mocker: "pytest_mock.MockFixture"):
|
|||
cfg = sine440_config()
|
||||
corr = CorrScope(cfg, Arguments(".", [FFplayOutputConfig()]))
|
||||
|
||||
render_frame = mocker.patch.object(MatplotlibRenderer, "render_frame")
|
||||
render_frame.side_effect = DummyException()
|
||||
update_main_lines = mocker.patch.object(MatplotlibRenderer, "update_main_lines")
|
||||
update_main_lines.side_effect = DummyException()
|
||||
with pytest.raises(DummyException):
|
||||
corr.play()
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional, TYPE_CHECKING
|
||||
from typing import Optional, TYPE_CHECKING, List
|
||||
|
||||
import matplotlib.colors
|
||||
import numpy as np
|
||||
|
@ -54,15 +54,16 @@ def test_default_colors(bg_str, fg_str, grid_str, data):
|
|||
antialiasing=False,
|
||||
)
|
||||
lcfg = LayoutConfig()
|
||||
datas = [data] * NPLOTS
|
||||
|
||||
r = MatplotlibRenderer(cfg, lcfg, NPLOTS, None)
|
||||
verify(r, bg_str, fg_str, grid_str, data)
|
||||
r = MatplotlibRenderer(cfg, lcfg, datas, None)
|
||||
verify(r, bg_str, fg_str, grid_str, datas)
|
||||
|
||||
# Ensure default ChannelConfig(line_color=None) does not override line color
|
||||
chan = ChannelConfig(wav_path="")
|
||||
channels = [chan] * NPLOTS
|
||||
r = MatplotlibRenderer(cfg, lcfg, NPLOTS, channels)
|
||||
verify(r, bg_str, fg_str, grid_str, data)
|
||||
r = MatplotlibRenderer(cfg, lcfg, datas, channels)
|
||||
verify(r, bg_str, fg_str, grid_str, datas)
|
||||
|
||||
|
||||
@all_colors
|
||||
|
@ -79,20 +80,25 @@ def test_line_colors(bg_str, fg_str, grid_str, data):
|
|||
antialiasing=False,
|
||||
)
|
||||
lcfg = LayoutConfig()
|
||||
datas = [data] * NPLOTS
|
||||
|
||||
chan = ChannelConfig(wav_path="", line_color=fg_str)
|
||||
channels = [chan] * NPLOTS
|
||||
r = MatplotlibRenderer(cfg, lcfg, NPLOTS, channels)
|
||||
verify(r, bg_str, fg_str, grid_str, data)
|
||||
r = MatplotlibRenderer(cfg, lcfg, datas, channels)
|
||||
verify(r, bg_str, fg_str, grid_str, datas)
|
||||
|
||||
|
||||
TOLERANCE = 3
|
||||
|
||||
|
||||
def verify(
|
||||
r: MatplotlibRenderer, bg_str, fg_str, grid_str: Optional[str], data: np.ndarray
|
||||
r: MatplotlibRenderer,
|
||||
bg_str,
|
||||
fg_str,
|
||||
grid_str: Optional[str],
|
||||
datas: List[np.ndarray],
|
||||
):
|
||||
r.render_frame([data] * NPLOTS)
|
||||
r.update_main_lines(datas)
|
||||
frame_colors: np.ndarray = np.frombuffer(r.get_frame(), dtype=np.uint8).reshape(
|
||||
(-1, RGB_DEPTH)
|
||||
)
|
||||
|
@ -107,6 +113,7 @@ def verify(
|
|||
else:
|
||||
grid_u8 = bg_u8
|
||||
|
||||
data = datas[0]
|
||||
assert (data.shape[1] > 1) == (data is RENDER_Y_STEREO)
|
||||
is_stereo = data.shape[1] > 1
|
||||
if is_stereo:
|
||||
|
|
Ładowanie…
Reference in New Issue