Refactor renderer API

pull/357/head
nyanpasu64 2019-04-04 04:13:21 -07:00
rodzic 2fc908dd5a
commit 3849f7fe85
7 zmienionych plików z 267 dodań i 210 usunięć

Wyświetl plik

@ -214,8 +214,9 @@ class CorrScope:
yield
def _load_renderer(self) -> Renderer:
dummy_datas = [channel.get_render_around(0) for channel in self.channels]
renderer = MatplotlibRenderer(
self.cfg.render, self.cfg.layout, self.nchan, self.cfg.channels
self.cfg.render, self.cfg.layout, dummy_datas, self.cfg.channels
)
return renderer
@ -293,7 +294,7 @@ class CorrScope:
if not_benchmarking or benchmark_mode >= BenchmarkMode.RENDER:
# Render frame
renderer.render_frame(render_datas)
renderer.update_main_lines(render_datas)
frame_data = renderer.get_frame()
if not_benchmarking or benchmark_mode == BenchmarkMode.OUTPUT:

Wyświetl plik

@ -153,12 +153,16 @@ class RendererLayout:
self.wave_nrow = nrows
self.wave_ncol = ncols
def arrange(self, region_factory: RegionFactory[Region]) -> List[List[Region]]:
def arrange(
self, region_factory: RegionFactory[Region], **kwargs
) -> List[List[Region]]:
"""
(row, column) are fed into region_factory in a row-major order [row][col].
Stereo channel pairs are extracted.
The results are possibly reshaped into column-major order [col][row].
**kwargs -> region_factory(**kwargs).
:return arr[wave][channel] = Region
"""
@ -229,7 +233,7 @@ class RendererLayout:
screen_edges,
wave_edges,
)
region = region_factory(chan_spec)
region = region_factory(chan_spec, **kwargs)
region_chan.append(region)
# Move to next channel position

Wyświetl plik

@ -1,9 +1,9 @@
import os
from abc import ABC, abstractmethod
from typing import Optional, List, TYPE_CHECKING, Any
from typing import Optional, List, TYPE_CHECKING, Any, Callable
import attr
import matplotlib
import matplotlib # do NOT import anything else until we call matplotlib.use().
import matplotlib.colors
import numpy as np
@ -43,6 +43,7 @@ from matplotlib.backends.backend_agg import FigureCanvasAgg
from matplotlib.figure import Figure
if TYPE_CHECKING:
from matplotlib.artist import Artist
from matplotlib.axes import Axes
from matplotlib.lines import Line2D
from corrscope.channel import ChannelConfig
@ -99,18 +100,24 @@ class LineParam:
color: str
UpdateLines = Callable[[List[np.ndarray]], None]
# TODO rename to Plotter
class Renderer(ABC):
def __init__(
self,
cfg: RendererConfig,
lcfg: "LayoutConfig",
nplots: int,
dummy_datas: List[np.ndarray],
channel_cfgs: Optional[List["ChannelConfig"]],
):
self.cfg = cfg
self.lcfg = lcfg
self.nplots = nplots
self.nplots = len(dummy_datas)
assert len(dummy_datas[0].shape) == 2, dummy_datas[0].shape
self.wave_nsamps = [data.shape[0] for data in dummy_datas]
self.wave_nchans = [data.shape[1] for data in dummy_datas]
# Load line colors.
if channel_cfgs is not None:
@ -127,8 +134,16 @@ class Renderer(ABC):
for color in line_colors
]
_update_main_lines: Optional[UpdateLines] = None
def update_main_lines(self, datas: List[np.ndarray]) -> None:
if self._update_main_lines is None:
self._update_main_lines = self.add_lines(datas)
self._update_main_lines(datas)
@abstractmethod
def render_frame(self, datas: List[np.ndarray]) -> None:
def add_lines(self, dummy_datas: List[np.ndarray]) -> UpdateLines:
...
@abstractmethod
@ -175,217 +190,221 @@ class MatplotlibRenderer(Renderer):
matplotlib.rcParams, "lines.antialiased", self.cfg.antialiasing
)
self._fig: "Figure"
self._setup_axes(self.wave_nchans)
# _axes2d[wave][chan] = Axes
self._axes2d: List[List["Axes"]] # set by set_layout()
self._artists: List["Artist"] = []
# _lines2d[wave][chan] = Line2D
self._lines2d: List[List[Line2D]] = []
self._lines_flat: List["Line2D"] = []
_fig: "Figure"
transparent = "#00000000"
# _axes2d[wave][chan] = Axes
# Primary, used to draw oscilloscope lines and gridlines.
_axes2d: List[List["Axes"]] # set by set_layout()
layout: RendererLayout
# _axes_mono[wave] = Axes
# Secondary, used for titles and debug plots.
_axes_mono: List["Axes"]
def _set_layout(self, wave_nchans: List[int]) -> None:
def _setup_axes(self, wave_nchans: List[int]) -> None:
"""
Creates a flat array of Matplotlib Axes, with the new layout.
Opens a window showing the Figure (and Axes).
Inputs: self.cfg, self.fig
Outputs: self.nrows, self.ncols, self.axes
Sets up each Axes with correct region limits.
"""
self.layout = RendererLayout(self.lcfg, wave_nchans)
self.layout_mono = RendererLayout(self.lcfg, [1] * self.nplots)
# Create Axes
# https://matplotlib.org/api/_as_gen/matplotlib.pyplot.subplots.html
if hasattr(self, "_fig"):
raise Exception("I don't currently expect to call _set_layout() twice")
raise Exception("I don't currently expect to call _setup_axes() twice")
# plt.close(self.fig)
grid_color = self.cfg.grid_color
cfg = self.cfg
self._fig = Figure()
FigureCanvasAgg(self._fig)
# RegionFactory
def axes_factory(r: RegionSpec) -> "Axes":
width = 1 / r.ncol
left = r.col / r.ncol
assert 0 <= left < 1
height = 1 / r.nrow
bottom = (r.nrow - r.row - 1) / r.nrow
assert 0 <= bottom < 1
# Disabling xticks/yticks is unnecessary, since we hide Axises.
ax = self._fig.add_axes([left, bottom, width, height], xticks=[], yticks=[])
if grid_color:
# Initialize borders
# Hide Axises
# (drawing them is very slow, and we disable ticks+labels anyway)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
# Background color
# ax.patch.set_fill(False) sets _fill=False,
# then calls _set_facecolor(...) "alpha = self._alpha if self._fill else 0".
# It is no faster than below.
ax.set_facecolor(self.transparent)
# Set border colors
for spine in ax.spines.values():
spine.set_color(grid_color)
def hide(key: str):
ax.spines[key].set_visible(False)
# Hide all axes except bottom-right.
hide("top")
hide("left")
# If bottom of screen, hide bottom. If right of screen, hide right.
if r.screen_edges & Edges.Bottom:
hide("bottom")
if r.screen_edges & Edges.Right:
hide("right")
# Dim stereo gridlines
if self.cfg.stereo_grid_opacity > 0:
dim_color = matplotlib.colors.to_rgba_array(grid_color)[0]
dim_color[-1] = self.cfg.stereo_grid_opacity
def dim(key: str):
ax.spines[key].set_color(dim_color)
else:
dim = hide
# If not bottom of wave, dim bottom. If not right of wave, dim right.
if not r.wave_edges & Edges.Bottom:
dim("bottom")
if not r.wave_edges & Edges.Right:
dim("right")
else:
ax.set_axis_off()
return ax
# Generate arrangement (using self.lcfg, wave_nchans)
# _axes2d[wave][chan] = Axes
self._axes2d = self.layout.arrange(axes_factory)
# Setup figure geometry
self._fig.set_dpi(DPI)
self._fig.set_size_inches(self.cfg.width / DPI, self.cfg.height / DPI)
FigureCanvasAgg(self._fig)
def render_frame(self, datas: List[np.ndarray]) -> None:
# Setup background
self._fig.set_facecolor(cfg.bg_color)
# Create Axes (using self.lcfg, wave_nchans)
# _axes2d[wave][chan] = Axes
self._axes2d = self.layout.arrange(self._axes_factory)
"""
Adding an axes using the same arguments as a previous axes
currently reuses the earlier instance.
In a future version, a new instance will always be created and returned.
Meanwhile, this warning can be suppressed, and the future behavior ensured,
by passing a unique label to each axes instance.
ax=fig.add_axes(label=) is unused, even if you call ax.legend().
"""
# _axes_mono[wave] = Axes
self._axes_mono = []
# Returns 2D list of [self.nplots][1]Axes.
axes_mono_2d = self.layout_mono.arrange(self._axes_factory, label="mono")
for axes_list in axes_mono_2d:
assert len(axes_list) == 1
self._axes_mono.extend(axes_list)
# Setup axes
for idx, N in enumerate(self.wave_nsamps):
wave_axes = self._axes2d[idx]
max_x = N - 1
def scale_axes(ax: "Axes"):
ax.set_xlim(0, max_x)
ax.set_ylim(-1, 1)
scale_axes(self._axes_mono[idx])
for ax in unique_by_id(wave_axes):
scale_axes(ax)
# Setup midlines (depends on max_x and wave_data)
midline_color = cfg.midline_color
midline_width = pixels(1)
# Not quite sure if midlines or gridlines draw on top
kw = dict(color=midline_color, linewidth=midline_width)
if cfg.v_midline:
# See Wave.get_around() docstring.
# wave_data[N//2] == self[sample], usually > 0.
ax.axvline(x=N // 2 - 0.5, **kw)
if cfg.h_midline:
ax.axhline(y=0, **kw)
self._save_background()
transparent = "#00000000"
# satisfies RegionFactory
def _axes_factory(self, r: RegionSpec, label: str = "") -> "Axes":
grid_color = self.cfg.grid_color
width = 1 / r.ncol
left = r.col / r.ncol
assert 0 <= left < 1
height = 1 / r.nrow
bottom = (r.nrow - r.row - 1) / r.nrow
assert 0 <= bottom < 1
# Disabling xticks/yticks is unnecessary, since we hide Axises.
ax = self._fig.add_axes(
[left, bottom, width, height], xticks=[], yticks=[], label=label
)
if grid_color:
# Initialize borders
# Hide Axises
# (drawing them is very slow, and we disable ticks+labels anyway)
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
# Background color
# ax.patch.set_fill(False) sets _fill=False,
# then calls _set_facecolor(...) "alpha = self._alpha if self._fill else 0".
# It is no faster than below.
ax.set_facecolor(self.transparent)
# Set border colors
for spine in ax.spines.values():
spine.set_color(grid_color)
def hide(key: str):
ax.spines[key].set_visible(False)
# Hide all axes except bottom-right.
hide("top")
hide("left")
# If bottom of screen, hide bottom. If right of screen, hide right.
if r.screen_edges & Edges.Bottom:
hide("bottom")
if r.screen_edges & Edges.Right:
hide("right")
# Dim stereo gridlines
if self.cfg.stereo_grid_opacity > 0:
dim_color = matplotlib.colors.to_rgba_array(grid_color)[0]
dim_color[-1] = self.cfg.stereo_grid_opacity
def dim(key: str):
ax.spines[key].set_color(dim_color)
else:
dim = hide
# If not bottom of wave, dim bottom. If not right of wave, dim right.
if not r.wave_edges & Edges.Bottom:
dim("bottom")
if not r.wave_edges & Edges.Right:
dim("right")
else:
ax.set_axis_off()
return ax
# Public API
def add_lines(self, dummy_datas: List[np.ndarray]) -> UpdateLines:
cfg = self.cfg
# Plot lines over background
line_width = pixels(cfg.line_width)
# Foreach wave, plot dummy data.
lines2d = []
for wave_idx, wave_data in enumerate(dummy_datas):
wave_zeros = np.zeros_like(wave_data)
wave_axes = self._axes2d[wave_idx]
wave_lines = []
# Foreach chan
for chan_idx, chan_zeros in enumerate(wave_zeros.T):
ax = wave_axes[chan_idx]
line_color = self._line_params[wave_idx].color
chan_line: Line2D = ax.plot(
chan_zeros, color=line_color, linewidth=line_width
)[0]
wave_lines.append(chan_line)
lines2d.append(wave_lines)
self._artists.extend(wave_lines)
return lambda datas: self._update_lines(lines2d, datas)
@staticmethod
def _update_lines(lines2d: "List[List[Line2D]]", datas: List[np.ndarray]) -> None:
"""
Preconditions:
- lines2d[wave][chan] = Line2D
- datas[wave] = ndarray, [samp][chan] = FLOAT
"""
nplots = len(lines2d)
ndata = len(datas)
if self.nplots != ndata:
if nplots != ndata:
raise ValueError(
f"incorrect data to plot: {self.nplots} plots but {ndata} datas"
f"incorrect data to plot: {nplots} plots but {ndata} dummy_datas"
)
# Initialize axes and draw waveform data
if not self._lines2d:
assert len(datas[0].shape) == 2, datas[0].shape
wave_nchans = [data.shape[1] for data in datas]
self._set_layout(wave_nchans)
cfg = self.cfg
# Setup background/axes
self._fig.set_facecolor(cfg.bg_color)
for idx, wave_data in enumerate(datas):
wave_axes = self._axes2d[idx]
for ax in unique_by_id(wave_axes):
N = len(wave_data)
max_x = N - 1
ax.set_xlim(0, max_x)
ax.set_ylim(-1, 1)
# Setup midlines (depends on max_x and wave_data)
midline_color = cfg.midline_color
midline_width = pixels(1)
# zorder=-100 still draws on top of gridlines :(
kw = dict(color=midline_color, linewidth=midline_width)
if cfg.v_midline:
# See Wave.get_around() docstring.
# wave_data[N//2] == self[sample], usually > 0.
ax.axvline(x=N // 2 - 0.5, **kw)
if cfg.h_midline:
ax.axhline(y=0, **kw)
self._save_background()
# Plot lines over background
line_width = pixels(cfg.line_width)
# Foreach wave
for wave_idx, wave_data in enumerate(datas):
wave_axes = self._axes2d[wave_idx]
wave_lines = []
# Foreach chan
for chan_idx, chan_data in enumerate(wave_data.T):
ax = wave_axes[chan_idx]
line_color = self._line_params[wave_idx].color
chan_line: Line2D = ax.plot(
chan_data, color=line_color, linewidth=line_width
)[0]
wave_lines.append(chan_line)
self._lines2d.append(wave_lines)
self._lines_flat.extend(wave_lines)
# Draw waveform data
else:
# Foreach wave
for wave_idx, wave_data in enumerate(datas):
wave_lines = self._lines2d[wave_idx]
# Foreach wave
for wave_idx, wave_data in enumerate(datas):
wave_lines = lines2d[wave_idx]
# Foreach chan
for chan_idx, chan_data in enumerate(wave_data.T):
chan_line = wave_lines[chan_idx]
chan_line.set_ydata(chan_data)
self._redraw_over_background()
bg_cache: Any # "matplotlib.backends._backend_agg.BufferRegion"
def _save_background(self) -> None:
""" Draw static background. """
# https://stackoverflow.com/a/8956211
# https://matplotlib.org/api/animation_api.html#funcanimation
fig = self._fig
fig.canvas.draw()
self.bg_cache = fig.canvas.copy_from_bbox(fig.bbox)
def _redraw_over_background(self) -> None:
""" Redraw animated elements of the image. """
canvas: FigureCanvasAgg = self._fig.canvas
canvas.restore_region(self.bg_cache)
for line in self._lines_flat:
line.axes.draw_artist(line)
# https://bastibe.de/2013-05-30-speeding-up-matplotlib.html
# thinks fig.canvas.blit(ax.bbox) leaks memory
# and fig.canvas.update() works.
# Except I found no memory leak...
# and update() doesn't exist in FigureCanvasBase when no GUI is present.
canvas.blit(self._fig.bbox)
# Foreach chan
for chan_idx, chan_data in enumerate(wave_data.T):
chan_line = wave_lines[chan_idx]
chan_line.set_ydata(chan_data)
# Output frames
def get_frame(self) -> ByteBuffer:
""" Returns ndarray of shape w,h,3. """
self._redraw_over_background()
canvas = self._fig.canvas
# Agg is the default noninteractive backend except on OSX.
@ -403,3 +422,26 @@ class MatplotlibRenderer(Renderer):
assert len(buffer_rgb) == w * h * RGB_DEPTH
return buffer_rgb
# Pre-rendered background
bg_cache: Any # "matplotlib.backends._backend_agg.BufferRegion"
def _save_background(self) -> None:
""" Draw static background. """
# https://stackoverflow.com/a/8956211
# https://matplotlib.org/api/animation_api.html#funcanimation
fig = self._fig
fig.canvas.draw()
self.bg_cache = fig.canvas.copy_from_bbox(fig.bbox)
def _redraw_over_background(self) -> None:
""" Redraw animated elements of the image. """
canvas: FigureCanvasAgg = self._fig.canvas
canvas.restore_region(self.bg_cache)
for artist in self._artists:
artist.axes.draw_artist(artist)
# canvas.blit(self._fig.bbox) is unnecessary when drawing off-screen.

Wyświetl plik

@ -127,9 +127,9 @@ def test_config_channel_width_stride(
assert _return_nsamp == channel._render_samp
assert _subsampling == channel._render_stride
# Inspect arguments to renderer.render_frame()
# Inspect arguments to renderer.update_main_lines()
# datas: List[np.ndarray]
(datas,), kwargs = renderer.render_frame.call_args
(datas,), kwargs = renderer.update_main_lines.call_args
render_data = datas[0]
assert len(render_data) == channel._render_samp

Wyświetl plik

@ -221,8 +221,9 @@ def test_renderer_layout():
lcfg = LayoutConfig(ncols=2)
nplots = 15
r = MatplotlibRenderer(cfg, lcfg, nplots, None)
r.render_frame([RENDER_Y_ZEROS] * nplots)
datas = [RENDER_Y_ZEROS] * nplots
r = MatplotlibRenderer(cfg, lcfg, datas, None)
r.update_main_lines(datas)
layout = r.layout
# 2 columns, 8 rows

Wyświetl plik

@ -84,10 +84,12 @@ def sine440_config():
# Calls MatplotlibRenderer, FFmpegOutput, FFmpeg.
def test_render_output():
""" Ensure rendering to output does not raise exceptions. """
renderer = MatplotlibRenderer(CFG.render, CFG.layout, nplots=1, channel_cfgs=None)
datas = [RENDER_Y_ZEROS]
renderer = MatplotlibRenderer(CFG.render, CFG.layout, datas, channel_cfgs=None)
out: FFmpegOutput = NULL_FFMPEG_OUTPUT(CFG)
renderer.render_frame([RENDER_Y_ZEROS])
renderer.update_main_lines(datas)
out.write_frame(renderer.get_frame())
assert out.close() == 0
@ -166,8 +168,8 @@ def test_corr_terminate_ffplay(Popen, mocker: "pytest_mock.MockFixture"):
cfg = sine440_config()
corr = CorrScope(cfg, Arguments(".", [FFplayOutputConfig()]))
render_frame = mocker.patch.object(MatplotlibRenderer, "render_frame")
render_frame.side_effect = DummyException()
update_main_lines = mocker.patch.object(MatplotlibRenderer, "update_main_lines")
update_main_lines.side_effect = DummyException()
with pytest.raises(DummyException):
corr.play()

Wyświetl plik

@ -1,4 +1,4 @@
from typing import Optional, TYPE_CHECKING
from typing import Optional, TYPE_CHECKING, List
import matplotlib.colors
import numpy as np
@ -54,15 +54,16 @@ def test_default_colors(bg_str, fg_str, grid_str, data):
antialiasing=False,
)
lcfg = LayoutConfig()
datas = [data] * NPLOTS
r = MatplotlibRenderer(cfg, lcfg, NPLOTS, None)
verify(r, bg_str, fg_str, grid_str, data)
r = MatplotlibRenderer(cfg, lcfg, datas, None)
verify(r, bg_str, fg_str, grid_str, datas)
# Ensure default ChannelConfig(line_color=None) does not override line color
chan = ChannelConfig(wav_path="")
channels = [chan] * NPLOTS
r = MatplotlibRenderer(cfg, lcfg, NPLOTS, channels)
verify(r, bg_str, fg_str, grid_str, data)
r = MatplotlibRenderer(cfg, lcfg, datas, channels)
verify(r, bg_str, fg_str, grid_str, datas)
@all_colors
@ -79,20 +80,25 @@ def test_line_colors(bg_str, fg_str, grid_str, data):
antialiasing=False,
)
lcfg = LayoutConfig()
datas = [data] * NPLOTS
chan = ChannelConfig(wav_path="", line_color=fg_str)
channels = [chan] * NPLOTS
r = MatplotlibRenderer(cfg, lcfg, NPLOTS, channels)
verify(r, bg_str, fg_str, grid_str, data)
r = MatplotlibRenderer(cfg, lcfg, datas, channels)
verify(r, bg_str, fg_str, grid_str, datas)
TOLERANCE = 3
def verify(
r: MatplotlibRenderer, bg_str, fg_str, grid_str: Optional[str], data: np.ndarray
r: MatplotlibRenderer,
bg_str,
fg_str,
grid_str: Optional[str],
datas: List[np.ndarray],
):
r.render_frame([data] * NPLOTS)
r.update_main_lines(datas)
frame_colors: np.ndarray = np.frombuffer(r.get_frame(), dtype=np.uint8).reshape(
(-1, RGB_DEPTH)
)
@ -107,6 +113,7 @@ def verify(
else:
grid_u8 = bg_u8
data = datas[0]
assert (data.shape[1] > 1) == (data is RENDER_Y_STEREO)
is_stereo = data.shape[1] > 1
if is_stereo: