kopia lustrzana https://github.com/corrscope/corrscope
Flatten Renderer hierarchy using multiple inheritance
Ensure backend implementations do not inherit from RendererFrontend, since they don't need to know.pull/357/head
rodzic
f3e0b75b70
commit
8fb107fec1
|
|
@ -1,3 +1,12 @@
|
|||
"""
|
||||
Backend implementations should not inherit from RendererFrontend,
|
||||
since they don't need to know.
|
||||
|
||||
Implementation: Multiple inheritance:
|
||||
Renderer inherits from (RendererFrontend, backend implementation).
|
||||
Backend implementation does not know about RendererFrontend.
|
||||
"""
|
||||
|
||||
import enum
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
|
|
@ -301,85 +310,14 @@ class _RendererBackend(ABC):
|
|||
def add_labels(self, labels: List[str]) -> Any:
|
||||
...
|
||||
|
||||
# Primarily used by RendererFrontend, not outside world.
|
||||
@abstractmethod
|
||||
def add_xy_line_mono(
|
||||
def _add_xy_line_mono(
|
||||
self, wave_idx: int, xs: Sequence[float], ys: Sequence[float], stride: int
|
||||
) -> CustomLine:
|
||||
...
|
||||
|
||||
|
||||
class RendererFrontend(_RendererBackend, ABC):
|
||||
def __init__(self, *args, **kwargs):
|
||||
_RendererBackend.__init__(self, *args, **kwargs)
|
||||
self._update_main_lines = None
|
||||
self._custom_lines = {}
|
||||
self._offsetable = defaultdict(list)
|
||||
|
||||
_update_main_lines: Optional[UpdateLines]
|
||||
|
||||
def update_main_lines(self, datas: List[np.ndarray]) -> None:
|
||||
if self._update_main_lines is None:
|
||||
self._update_main_lines = self.add_lines_stereo(datas, self.render_strides)
|
||||
|
||||
self._update_main_lines(datas)
|
||||
|
||||
_custom_lines: Dict[Any, CustomLine]
|
||||
_offsetable: DefaultDict[int, MutableSequence[CustomLine]]
|
||||
|
||||
def update_custom_line(
|
||||
self,
|
||||
name: str,
|
||||
wave_idx: int,
|
||||
stride: int,
|
||||
data: np.ndarray,
|
||||
*,
|
||||
offset: bool = True,
|
||||
):
|
||||
data = data.copy()
|
||||
key = (name, wave_idx)
|
||||
|
||||
if key not in self._custom_lines:
|
||||
line = self._add_line_mono(wave_idx, stride, data)
|
||||
self._custom_lines[key] = line
|
||||
if offset:
|
||||
self._offsetable[wave_idx].append(line)
|
||||
else:
|
||||
line = self._custom_lines[key]
|
||||
|
||||
line.set_ydata(data)
|
||||
|
||||
def update_vline(
|
||||
self, name: str, wave_idx: int, stride: int, x: int, *, offset: bool = True
|
||||
):
|
||||
key = (name, wave_idx)
|
||||
if key not in self._custom_lines:
|
||||
line = self._add_vline_mono(wave_idx, stride)
|
||||
self._custom_lines[key] = line
|
||||
if offset:
|
||||
self._offsetable[wave_idx].append(line)
|
||||
else:
|
||||
line = self._custom_lines[key]
|
||||
|
||||
line.xdata = [x * stride] * 2
|
||||
self._custom_lines[key].set_xdata(line.xdata)
|
||||
|
||||
def offset_viewport(self, wave_idx: int, viewport_offset: float):
|
||||
line_offset = -viewport_offset
|
||||
|
||||
for line in self._offsetable[wave_idx]:
|
||||
line.set_xdata(line.xdata + line_offset * line.stride)
|
||||
|
||||
def _add_line_mono(
|
||||
self, wave_idx: int, stride: int, dummy_data: np.ndarray
|
||||
) -> CustomLine:
|
||||
ys = np.zeros_like(dummy_data)
|
||||
xs = calc_xs(len(ys), stride)
|
||||
return self.add_xy_line_mono(wave_idx, xs, ys, stride)
|
||||
|
||||
def _add_vline_mono(self, wave_idx: int, stride: int) -> CustomLine:
|
||||
return self.add_xy_line_mono(wave_idx, [0, 0], [-1, 1], stride)
|
||||
|
||||
|
||||
# See Wave.get_around() and designNotes.md.
|
||||
# Viewport functions
|
||||
def calc_limits(N: int, viewport_stride: float) -> Tuple[float, float]:
|
||||
|
|
@ -414,9 +352,9 @@ def px_from_points(pt: Point) -> Pixel:
|
|||
return pt * PIXELS_PER_PT
|
||||
|
||||
|
||||
class AbstractMatplotlibRenderer(RendererFrontend, ABC):
|
||||
class AbstractMatplotlibRenderer(_RendererBackend, ABC):
|
||||
"""Matplotlib renderer which can use any backend (agg, mplcairo).
|
||||
To pick a backend, subclass and set canvas_type at the class level.
|
||||
To pick a backend, subclass and set _canvas_type at the class level.
|
||||
"""
|
||||
|
||||
_canvas_type: Type["FigureCanvasBase"] = abstract_classvar
|
||||
|
|
@ -427,7 +365,7 @@ class AbstractMatplotlibRenderer(RendererFrontend, ABC):
|
|||
pass
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
RendererFrontend.__init__(self, *args, **kwargs)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
dict.__setitem__(
|
||||
matplotlib.rcParams, "lines.antialiased", self.cfg.antialiasing
|
||||
|
|
@ -681,7 +619,7 @@ class AbstractMatplotlibRenderer(RendererFrontend, ABC):
|
|||
chan_line = wave_lines[chan_idx]
|
||||
chan_line.set_ydata(chan_data)
|
||||
|
||||
def add_xy_line_mono(
|
||||
def _add_xy_line_mono(
|
||||
self, wave_idx: int, xs: Sequence[float], ys: Sequence[float], stride: int
|
||||
) -> CustomLine:
|
||||
cfg = self.cfg
|
||||
|
|
@ -813,7 +751,7 @@ class MatplotlibAggRenderer(AbstractMatplotlibRenderer):
|
|||
def _canvas_to_bytes(canvas: FigureCanvasAgg) -> ByteBuffer:
|
||||
return canvas.tostring_rgb()
|
||||
|
||||
# implements BaseRenderer
|
||||
# Implements _RendererBackend.
|
||||
bytes_per_pixel = 3
|
||||
ffmpeg_pixel_format = "rgb24"
|
||||
|
||||
|
|
@ -823,4 +761,92 @@ class MatplotlibAggRenderer(AbstractMatplotlibRenderer):
|
|||
return np.array([round(c * 255) for c in to_rgb(c)], dtype=int)
|
||||
|
||||
|
||||
Renderer = MatplotlibAggRenderer
|
||||
class RendererFrontend(_RendererBackend, ABC):
|
||||
"""Wrapper around _RendererBackend implementations, providing a better interface."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self._update_main_lines = None
|
||||
self._custom_lines = {} # type: Dict[Any, CustomLine]
|
||||
self._vlines = {} # type: Dict[Any, CustomLine]
|
||||
self._offsetable = defaultdict(list)
|
||||
|
||||
# Overrides implementations of _RendererBackend.
|
||||
def get_frame(self) -> ByteBuffer:
|
||||
out = super().get_frame()
|
||||
|
||||
for line in self._custom_lines.values():
|
||||
line.set_ydata(0 * line.xdata)
|
||||
|
||||
for line in self._vlines.values():
|
||||
line.set_xdata(0 * line.xdata)
|
||||
return out
|
||||
|
||||
# New methods.
|
||||
_update_main_lines: Optional[UpdateLines]
|
||||
|
||||
def update_main_lines(self, datas: List[np.ndarray]) -> None:
|
||||
if self._update_main_lines is None:
|
||||
self._update_main_lines = self.add_lines_stereo(datas, self.render_strides)
|
||||
|
||||
self._update_main_lines(datas)
|
||||
|
||||
_offsetable: DefaultDict[int, MutableSequence[CustomLine]]
|
||||
|
||||
def update_custom_line(
|
||||
self,
|
||||
name: str,
|
||||
wave_idx: int,
|
||||
stride: int,
|
||||
data: np.ndarray,
|
||||
*,
|
||||
offset: bool = True,
|
||||
):
|
||||
data = data.copy()
|
||||
key = (name, wave_idx)
|
||||
|
||||
if key not in self._custom_lines:
|
||||
line = self._add_line_mono(wave_idx, stride, data)
|
||||
self._custom_lines[key] = line
|
||||
if offset:
|
||||
self._offsetable[wave_idx].append(line)
|
||||
else:
|
||||
line = self._custom_lines[key]
|
||||
|
||||
line.set_ydata(data)
|
||||
|
||||
def update_vline(
|
||||
self, name: str, wave_idx: int, stride: int, x: int, *, offset: bool = True
|
||||
):
|
||||
key = (name, wave_idx)
|
||||
if key not in self._vlines:
|
||||
line = self._add_vline_mono(wave_idx, stride)
|
||||
self._vlines[key] = line
|
||||
if offset:
|
||||
self._offsetable[wave_idx].append(line)
|
||||
else:
|
||||
line = self._vlines[key]
|
||||
|
||||
line.xdata = [x * stride] * 2
|
||||
line.set_xdata(line.xdata)
|
||||
|
||||
def offset_viewport(self, wave_idx: int, viewport_offset: float):
|
||||
line_offset = -viewport_offset
|
||||
|
||||
for line in self._offsetable[wave_idx]:
|
||||
line.set_xdata(line.xdata + line_offset * line.stride)
|
||||
|
||||
def _add_line_mono(
|
||||
self, wave_idx: int, stride: int, dummy_data: np.ndarray
|
||||
) -> CustomLine:
|
||||
ys = np.zeros_like(dummy_data)
|
||||
xs = calc_xs(len(ys), stride)
|
||||
return self._add_xy_line_mono(wave_idx, xs, ys, stride)
|
||||
|
||||
def _add_vline_mono(self, wave_idx: int, stride: int) -> CustomLine:
|
||||
return self._add_xy_line_mono(wave_idx, [0, 0], [-1, 1], stride)
|
||||
|
||||
|
||||
class Renderer(RendererFrontend, MatplotlibAggRenderer):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ from corrscope.renderer import (
|
|||
calc_limits,
|
||||
calc_xs,
|
||||
calc_center,
|
||||
AbstractMatplotlibRenderer,
|
||||
RendererFrontend,
|
||||
)
|
||||
from corrscope.util import perr
|
||||
from corrscope.wave import Flatten
|
||||
|
|
@ -412,7 +414,7 @@ def verify_res_divisor_rounding(
|
|||
cfg.before_preview()
|
||||
|
||||
if speed_hack:
|
||||
mocker.patch.object(Renderer, "_save_background")
|
||||
mocker.patch.object(AbstractMatplotlibRenderer, "_save_background")
|
||||
datas = []
|
||||
else:
|
||||
datas = [RENDER_Y_ZEROS]
|
||||
|
|
@ -478,3 +480,31 @@ def test_renderer_knows_stride(mocker: "pytest_mock.MockFixture", integration: b
|
|||
corr_cfg.render, corr_cfg.layout, [data], [chan_cfg], [channel]
|
||||
)
|
||||
assert renderer.render_strides == [subsampling * width_mul]
|
||||
|
||||
|
||||
# Multiple inheritance tests
|
||||
def test_frontend_overrides_backend(mocker: "pytest_mock.MockFixture"):
|
||||
"""
|
||||
class Renderer inherits from (RendererFrontend, backend).
|
||||
|
||||
RendererFrontend.get_frame() is a wrapper around backend.get_frame()
|
||||
and should override it (RendererFrontend should come first in MRO).
|
||||
|
||||
Make sure RendererFrontend methods overshadow backend methods.
|
||||
"""
|
||||
|
||||
# If RendererFrontend.get_frame() override is removed, delete this entire test.
|
||||
frontend_get_frame = mocker.spy(RendererFrontend, "get_frame")
|
||||
backend_get_frame = mocker.spy(AbstractMatplotlibRenderer, "get_frame")
|
||||
|
||||
corr_cfg = default_config()
|
||||
chan_cfg = ChannelConfig("tests/sine440.wav")
|
||||
channel = Channel(chan_cfg, corr_cfg, channel_idx=0)
|
||||
data = channel.get_render_around(0)
|
||||
|
||||
renderer = Renderer(corr_cfg.render, corr_cfg.layout, [data], [chan_cfg], [channel])
|
||||
renderer.update_main_lines([data])
|
||||
renderer.get_frame()
|
||||
|
||||
assert frontend_get_frame.call_count == 1
|
||||
assert backend_get_frame.call_count == 1
|
||||
|
|
|
|||
Ładowanie…
Reference in New Issue