Flatten Renderer hierarchy using multiple inheritance

Ensure backend implementations do not inherit from RendererFrontend,
since they don't need to know.
pull/357/head
nyanpasu64 2019-06-09 11:41:29 -07:00
rodzic f3e0b75b70
commit 8fb107fec1
2 zmienionych plików z 136 dodań i 80 usunięć

Wyświetl plik

@ -1,3 +1,12 @@
"""
Backend implementations should not inherit from RendererFrontend,
since they don't need to know.
Implementation: Multiple inheritance:
Renderer inherits from (RendererFrontend, backend implementation).
Backend implementation does not know about RendererFrontend.
"""
import enum import enum
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
@ -301,85 +310,14 @@ class _RendererBackend(ABC):
def add_labels(self, labels: List[str]) -> Any: def add_labels(self, labels: List[str]) -> Any:
... ...
# Primarily used by RendererFrontend, not outside world.
@abstractmethod @abstractmethod
def add_xy_line_mono( def _add_xy_line_mono(
self, wave_idx: int, xs: Sequence[float], ys: Sequence[float], stride: int self, wave_idx: int, xs: Sequence[float], ys: Sequence[float], stride: int
) -> CustomLine: ) -> CustomLine:
... ...
class RendererFrontend(_RendererBackend, ABC):
def __init__(self, *args, **kwargs):
_RendererBackend.__init__(self, *args, **kwargs)
self._update_main_lines = None
self._custom_lines = {}
self._offsetable = defaultdict(list)
_update_main_lines: Optional[UpdateLines]
def update_main_lines(self, datas: List[np.ndarray]) -> None:
if self._update_main_lines is None:
self._update_main_lines = self.add_lines_stereo(datas, self.render_strides)
self._update_main_lines(datas)
_custom_lines: Dict[Any, CustomLine]
_offsetable: DefaultDict[int, MutableSequence[CustomLine]]
def update_custom_line(
self,
name: str,
wave_idx: int,
stride: int,
data: np.ndarray,
*,
offset: bool = True,
):
data = data.copy()
key = (name, wave_idx)
if key not in self._custom_lines:
line = self._add_line_mono(wave_idx, stride, data)
self._custom_lines[key] = line
if offset:
self._offsetable[wave_idx].append(line)
else:
line = self._custom_lines[key]
line.set_ydata(data)
def update_vline(
self, name: str, wave_idx: int, stride: int, x: int, *, offset: bool = True
):
key = (name, wave_idx)
if key not in self._custom_lines:
line = self._add_vline_mono(wave_idx, stride)
self._custom_lines[key] = line
if offset:
self._offsetable[wave_idx].append(line)
else:
line = self._custom_lines[key]
line.xdata = [x * stride] * 2
self._custom_lines[key].set_xdata(line.xdata)
def offset_viewport(self, wave_idx: int, viewport_offset: float):
line_offset = -viewport_offset
for line in self._offsetable[wave_idx]:
line.set_xdata(line.xdata + line_offset * line.stride)
def _add_line_mono(
self, wave_idx: int, stride: int, dummy_data: np.ndarray
) -> CustomLine:
ys = np.zeros_like(dummy_data)
xs = calc_xs(len(ys), stride)
return self.add_xy_line_mono(wave_idx, xs, ys, stride)
def _add_vline_mono(self, wave_idx: int, stride: int) -> CustomLine:
return self.add_xy_line_mono(wave_idx, [0, 0], [-1, 1], stride)
# See Wave.get_around() and designNotes.md. # See Wave.get_around() and designNotes.md.
# Viewport functions # Viewport functions
def calc_limits(N: int, viewport_stride: float) -> Tuple[float, float]: def calc_limits(N: int, viewport_stride: float) -> Tuple[float, float]:
@ -414,9 +352,9 @@ def px_from_points(pt: Point) -> Pixel:
return pt * PIXELS_PER_PT return pt * PIXELS_PER_PT
class AbstractMatplotlibRenderer(RendererFrontend, ABC): class AbstractMatplotlibRenderer(_RendererBackend, ABC):
"""Matplotlib renderer which can use any backend (agg, mplcairo). """Matplotlib renderer which can use any backend (agg, mplcairo).
To pick a backend, subclass and set canvas_type at the class level. To pick a backend, subclass and set _canvas_type at the class level.
""" """
_canvas_type: Type["FigureCanvasBase"] = abstract_classvar _canvas_type: Type["FigureCanvasBase"] = abstract_classvar
@ -427,7 +365,7 @@ class AbstractMatplotlibRenderer(RendererFrontend, ABC):
pass pass
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
RendererFrontend.__init__(self, *args, **kwargs) super().__init__(*args, **kwargs)
dict.__setitem__( dict.__setitem__(
matplotlib.rcParams, "lines.antialiased", self.cfg.antialiasing matplotlib.rcParams, "lines.antialiased", self.cfg.antialiasing
@ -681,7 +619,7 @@ class AbstractMatplotlibRenderer(RendererFrontend, ABC):
chan_line = wave_lines[chan_idx] chan_line = wave_lines[chan_idx]
chan_line.set_ydata(chan_data) chan_line.set_ydata(chan_data)
def add_xy_line_mono( def _add_xy_line_mono(
self, wave_idx: int, xs: Sequence[float], ys: Sequence[float], stride: int self, wave_idx: int, xs: Sequence[float], ys: Sequence[float], stride: int
) -> CustomLine: ) -> CustomLine:
cfg = self.cfg cfg = self.cfg
@ -813,7 +751,7 @@ class MatplotlibAggRenderer(AbstractMatplotlibRenderer):
def _canvas_to_bytes(canvas: FigureCanvasAgg) -> ByteBuffer: def _canvas_to_bytes(canvas: FigureCanvasAgg) -> ByteBuffer:
return canvas.tostring_rgb() return canvas.tostring_rgb()
# implements BaseRenderer # Implements _RendererBackend.
bytes_per_pixel = 3 bytes_per_pixel = 3
ffmpeg_pixel_format = "rgb24" ffmpeg_pixel_format = "rgb24"
@ -823,4 +761,92 @@ class MatplotlibAggRenderer(AbstractMatplotlibRenderer):
return np.array([round(c * 255) for c in to_rgb(c)], dtype=int) return np.array([round(c * 255) for c in to_rgb(c)], dtype=int)
Renderer = MatplotlibAggRenderer class RendererFrontend(_RendererBackend, ABC):
"""Wrapper around _RendererBackend implementations, providing a better interface."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._update_main_lines = None
self._custom_lines = {} # type: Dict[Any, CustomLine]
self._vlines = {} # type: Dict[Any, CustomLine]
self._offsetable = defaultdict(list)
# Overrides implementations of _RendererBackend.
def get_frame(self) -> ByteBuffer:
out = super().get_frame()
for line in self._custom_lines.values():
line.set_ydata(0 * line.xdata)
for line in self._vlines.values():
line.set_xdata(0 * line.xdata)
return out
# New methods.
_update_main_lines: Optional[UpdateLines]
def update_main_lines(self, datas: List[np.ndarray]) -> None:
if self._update_main_lines is None:
self._update_main_lines = self.add_lines_stereo(datas, self.render_strides)
self._update_main_lines(datas)
_offsetable: DefaultDict[int, MutableSequence[CustomLine]]
def update_custom_line(
self,
name: str,
wave_idx: int,
stride: int,
data: np.ndarray,
*,
offset: bool = True,
):
data = data.copy()
key = (name, wave_idx)
if key not in self._custom_lines:
line = self._add_line_mono(wave_idx, stride, data)
self._custom_lines[key] = line
if offset:
self._offsetable[wave_idx].append(line)
else:
line = self._custom_lines[key]
line.set_ydata(data)
def update_vline(
self, name: str, wave_idx: int, stride: int, x: int, *, offset: bool = True
):
key = (name, wave_idx)
if key not in self._vlines:
line = self._add_vline_mono(wave_idx, stride)
self._vlines[key] = line
if offset:
self._offsetable[wave_idx].append(line)
else:
line = self._vlines[key]
line.xdata = [x * stride] * 2
line.set_xdata(line.xdata)
def offset_viewport(self, wave_idx: int, viewport_offset: float):
line_offset = -viewport_offset
for line in self._offsetable[wave_idx]:
line.set_xdata(line.xdata + line_offset * line.stride)
def _add_line_mono(
self, wave_idx: int, stride: int, dummy_data: np.ndarray
) -> CustomLine:
ys = np.zeros_like(dummy_data)
xs = calc_xs(len(ys), stride)
return self._add_xy_line_mono(wave_idx, xs, ys, stride)
def _add_vline_mono(self, wave_idx: int, stride: int) -> CustomLine:
return self._add_xy_line_mono(wave_idx, [0, 0], [-1, 1], stride)
class Renderer(RendererFrontend, MatplotlibAggRenderer):
pass

Wyświetl plik

@ -18,6 +18,8 @@ from corrscope.renderer import (
calc_limits, calc_limits,
calc_xs, calc_xs,
calc_center, calc_center,
AbstractMatplotlibRenderer,
RendererFrontend,
) )
from corrscope.util import perr from corrscope.util import perr
from corrscope.wave import Flatten from corrscope.wave import Flatten
@ -412,7 +414,7 @@ def verify_res_divisor_rounding(
cfg.before_preview() cfg.before_preview()
if speed_hack: if speed_hack:
mocker.patch.object(Renderer, "_save_background") mocker.patch.object(AbstractMatplotlibRenderer, "_save_background")
datas = [] datas = []
else: else:
datas = [RENDER_Y_ZEROS] datas = [RENDER_Y_ZEROS]
@ -478,3 +480,31 @@ def test_renderer_knows_stride(mocker: "pytest_mock.MockFixture", integration: b
corr_cfg.render, corr_cfg.layout, [data], [chan_cfg], [channel] corr_cfg.render, corr_cfg.layout, [data], [chan_cfg], [channel]
) )
assert renderer.render_strides == [subsampling * width_mul] assert renderer.render_strides == [subsampling * width_mul]
# Multiple inheritance tests
def test_frontend_overrides_backend(mocker: "pytest_mock.MockFixture"):
"""
class Renderer inherits from (RendererFrontend, backend).
RendererFrontend.get_frame() is a wrapper around backend.get_frame()
and should override it (RendererFrontend should come first in MRO).
Make sure RendererFrontend methods overshadow backend methods.
"""
# If RendererFrontend.get_frame() override is removed, delete this entire test.
frontend_get_frame = mocker.spy(RendererFrontend, "get_frame")
backend_get_frame = mocker.spy(AbstractMatplotlibRenderer, "get_frame")
corr_cfg = default_config()
chan_cfg = ChannelConfig("tests/sine440.wav")
channel = Channel(chan_cfg, corr_cfg, channel_idx=0)
data = channel.get_render_around(0)
renderer = Renderer(corr_cfg.render, corr_cfg.layout, [data], [chan_cfg], [channel])
renderer.update_main_lines([data])
renderer.get_frame()
assert frontend_get_frame.call_count == 1
assert backend_get_frame.call_count == 1