kopia lustrzana https://github.com/corrscope/corrscope
Flatten Renderer hierarchy using multiple inheritance
Ensure backend implementations do not inherit from RendererFrontend, since they don't need to know.pull/357/head
rodzic
f3e0b75b70
commit
8fb107fec1
|
|
@ -1,3 +1,12 @@
|
||||||
|
"""
|
||||||
|
Backend implementations should not inherit from RendererFrontend,
|
||||||
|
since they don't need to know.
|
||||||
|
|
||||||
|
Implementation: Multiple inheritance:
|
||||||
|
Renderer inherits from (RendererFrontend, backend implementation).
|
||||||
|
Backend implementation does not know about RendererFrontend.
|
||||||
|
"""
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
import os
|
import os
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
@ -301,85 +310,14 @@ class _RendererBackend(ABC):
|
||||||
def add_labels(self, labels: List[str]) -> Any:
|
def add_labels(self, labels: List[str]) -> Any:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
# Primarily used by RendererFrontend, not outside world.
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def add_xy_line_mono(
|
def _add_xy_line_mono(
|
||||||
self, wave_idx: int, xs: Sequence[float], ys: Sequence[float], stride: int
|
self, wave_idx: int, xs: Sequence[float], ys: Sequence[float], stride: int
|
||||||
) -> CustomLine:
|
) -> CustomLine:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
class RendererFrontend(_RendererBackend, ABC):
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
_RendererBackend.__init__(self, *args, **kwargs)
|
|
||||||
self._update_main_lines = None
|
|
||||||
self._custom_lines = {}
|
|
||||||
self._offsetable = defaultdict(list)
|
|
||||||
|
|
||||||
_update_main_lines: Optional[UpdateLines]
|
|
||||||
|
|
||||||
def update_main_lines(self, datas: List[np.ndarray]) -> None:
|
|
||||||
if self._update_main_lines is None:
|
|
||||||
self._update_main_lines = self.add_lines_stereo(datas, self.render_strides)
|
|
||||||
|
|
||||||
self._update_main_lines(datas)
|
|
||||||
|
|
||||||
_custom_lines: Dict[Any, CustomLine]
|
|
||||||
_offsetable: DefaultDict[int, MutableSequence[CustomLine]]
|
|
||||||
|
|
||||||
def update_custom_line(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
wave_idx: int,
|
|
||||||
stride: int,
|
|
||||||
data: np.ndarray,
|
|
||||||
*,
|
|
||||||
offset: bool = True,
|
|
||||||
):
|
|
||||||
data = data.copy()
|
|
||||||
key = (name, wave_idx)
|
|
||||||
|
|
||||||
if key not in self._custom_lines:
|
|
||||||
line = self._add_line_mono(wave_idx, stride, data)
|
|
||||||
self._custom_lines[key] = line
|
|
||||||
if offset:
|
|
||||||
self._offsetable[wave_idx].append(line)
|
|
||||||
else:
|
|
||||||
line = self._custom_lines[key]
|
|
||||||
|
|
||||||
line.set_ydata(data)
|
|
||||||
|
|
||||||
def update_vline(
|
|
||||||
self, name: str, wave_idx: int, stride: int, x: int, *, offset: bool = True
|
|
||||||
):
|
|
||||||
key = (name, wave_idx)
|
|
||||||
if key not in self._custom_lines:
|
|
||||||
line = self._add_vline_mono(wave_idx, stride)
|
|
||||||
self._custom_lines[key] = line
|
|
||||||
if offset:
|
|
||||||
self._offsetable[wave_idx].append(line)
|
|
||||||
else:
|
|
||||||
line = self._custom_lines[key]
|
|
||||||
|
|
||||||
line.xdata = [x * stride] * 2
|
|
||||||
self._custom_lines[key].set_xdata(line.xdata)
|
|
||||||
|
|
||||||
def offset_viewport(self, wave_idx: int, viewport_offset: float):
|
|
||||||
line_offset = -viewport_offset
|
|
||||||
|
|
||||||
for line in self._offsetable[wave_idx]:
|
|
||||||
line.set_xdata(line.xdata + line_offset * line.stride)
|
|
||||||
|
|
||||||
def _add_line_mono(
|
|
||||||
self, wave_idx: int, stride: int, dummy_data: np.ndarray
|
|
||||||
) -> CustomLine:
|
|
||||||
ys = np.zeros_like(dummy_data)
|
|
||||||
xs = calc_xs(len(ys), stride)
|
|
||||||
return self.add_xy_line_mono(wave_idx, xs, ys, stride)
|
|
||||||
|
|
||||||
def _add_vline_mono(self, wave_idx: int, stride: int) -> CustomLine:
|
|
||||||
return self.add_xy_line_mono(wave_idx, [0, 0], [-1, 1], stride)
|
|
||||||
|
|
||||||
|
|
||||||
# See Wave.get_around() and designNotes.md.
|
# See Wave.get_around() and designNotes.md.
|
||||||
# Viewport functions
|
# Viewport functions
|
||||||
def calc_limits(N: int, viewport_stride: float) -> Tuple[float, float]:
|
def calc_limits(N: int, viewport_stride: float) -> Tuple[float, float]:
|
||||||
|
|
@ -414,9 +352,9 @@ def px_from_points(pt: Point) -> Pixel:
|
||||||
return pt * PIXELS_PER_PT
|
return pt * PIXELS_PER_PT
|
||||||
|
|
||||||
|
|
||||||
class AbstractMatplotlibRenderer(RendererFrontend, ABC):
|
class AbstractMatplotlibRenderer(_RendererBackend, ABC):
|
||||||
"""Matplotlib renderer which can use any backend (agg, mplcairo).
|
"""Matplotlib renderer which can use any backend (agg, mplcairo).
|
||||||
To pick a backend, subclass and set canvas_type at the class level.
|
To pick a backend, subclass and set _canvas_type at the class level.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_canvas_type: Type["FigureCanvasBase"] = abstract_classvar
|
_canvas_type: Type["FigureCanvasBase"] = abstract_classvar
|
||||||
|
|
@ -427,7 +365,7 @@ class AbstractMatplotlibRenderer(RendererFrontend, ABC):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
RendererFrontend.__init__(self, *args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
dict.__setitem__(
|
dict.__setitem__(
|
||||||
matplotlib.rcParams, "lines.antialiased", self.cfg.antialiasing
|
matplotlib.rcParams, "lines.antialiased", self.cfg.antialiasing
|
||||||
|
|
@ -681,7 +619,7 @@ class AbstractMatplotlibRenderer(RendererFrontend, ABC):
|
||||||
chan_line = wave_lines[chan_idx]
|
chan_line = wave_lines[chan_idx]
|
||||||
chan_line.set_ydata(chan_data)
|
chan_line.set_ydata(chan_data)
|
||||||
|
|
||||||
def add_xy_line_mono(
|
def _add_xy_line_mono(
|
||||||
self, wave_idx: int, xs: Sequence[float], ys: Sequence[float], stride: int
|
self, wave_idx: int, xs: Sequence[float], ys: Sequence[float], stride: int
|
||||||
) -> CustomLine:
|
) -> CustomLine:
|
||||||
cfg = self.cfg
|
cfg = self.cfg
|
||||||
|
|
@ -813,7 +751,7 @@ class MatplotlibAggRenderer(AbstractMatplotlibRenderer):
|
||||||
def _canvas_to_bytes(canvas: FigureCanvasAgg) -> ByteBuffer:
|
def _canvas_to_bytes(canvas: FigureCanvasAgg) -> ByteBuffer:
|
||||||
return canvas.tostring_rgb()
|
return canvas.tostring_rgb()
|
||||||
|
|
||||||
# implements BaseRenderer
|
# Implements _RendererBackend.
|
||||||
bytes_per_pixel = 3
|
bytes_per_pixel = 3
|
||||||
ffmpeg_pixel_format = "rgb24"
|
ffmpeg_pixel_format = "rgb24"
|
||||||
|
|
||||||
|
|
@ -823,4 +761,92 @@ class MatplotlibAggRenderer(AbstractMatplotlibRenderer):
|
||||||
return np.array([round(c * 255) for c in to_rgb(c)], dtype=int)
|
return np.array([round(c * 255) for c in to_rgb(c)], dtype=int)
|
||||||
|
|
||||||
|
|
||||||
Renderer = MatplotlibAggRenderer
|
class RendererFrontend(_RendererBackend, ABC):
|
||||||
|
"""Wrapper around _RendererBackend implementations, providing a better interface."""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
self._update_main_lines = None
|
||||||
|
self._custom_lines = {} # type: Dict[Any, CustomLine]
|
||||||
|
self._vlines = {} # type: Dict[Any, CustomLine]
|
||||||
|
self._offsetable = defaultdict(list)
|
||||||
|
|
||||||
|
# Overrides implementations of _RendererBackend.
|
||||||
|
def get_frame(self) -> ByteBuffer:
|
||||||
|
out = super().get_frame()
|
||||||
|
|
||||||
|
for line in self._custom_lines.values():
|
||||||
|
line.set_ydata(0 * line.xdata)
|
||||||
|
|
||||||
|
for line in self._vlines.values():
|
||||||
|
line.set_xdata(0 * line.xdata)
|
||||||
|
return out
|
||||||
|
|
||||||
|
# New methods.
|
||||||
|
_update_main_lines: Optional[UpdateLines]
|
||||||
|
|
||||||
|
def update_main_lines(self, datas: List[np.ndarray]) -> None:
|
||||||
|
if self._update_main_lines is None:
|
||||||
|
self._update_main_lines = self.add_lines_stereo(datas, self.render_strides)
|
||||||
|
|
||||||
|
self._update_main_lines(datas)
|
||||||
|
|
||||||
|
_offsetable: DefaultDict[int, MutableSequence[CustomLine]]
|
||||||
|
|
||||||
|
def update_custom_line(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
wave_idx: int,
|
||||||
|
stride: int,
|
||||||
|
data: np.ndarray,
|
||||||
|
*,
|
||||||
|
offset: bool = True,
|
||||||
|
):
|
||||||
|
data = data.copy()
|
||||||
|
key = (name, wave_idx)
|
||||||
|
|
||||||
|
if key not in self._custom_lines:
|
||||||
|
line = self._add_line_mono(wave_idx, stride, data)
|
||||||
|
self._custom_lines[key] = line
|
||||||
|
if offset:
|
||||||
|
self._offsetable[wave_idx].append(line)
|
||||||
|
else:
|
||||||
|
line = self._custom_lines[key]
|
||||||
|
|
||||||
|
line.set_ydata(data)
|
||||||
|
|
||||||
|
def update_vline(
|
||||||
|
self, name: str, wave_idx: int, stride: int, x: int, *, offset: bool = True
|
||||||
|
):
|
||||||
|
key = (name, wave_idx)
|
||||||
|
if key not in self._vlines:
|
||||||
|
line = self._add_vline_mono(wave_idx, stride)
|
||||||
|
self._vlines[key] = line
|
||||||
|
if offset:
|
||||||
|
self._offsetable[wave_idx].append(line)
|
||||||
|
else:
|
||||||
|
line = self._vlines[key]
|
||||||
|
|
||||||
|
line.xdata = [x * stride] * 2
|
||||||
|
line.set_xdata(line.xdata)
|
||||||
|
|
||||||
|
def offset_viewport(self, wave_idx: int, viewport_offset: float):
|
||||||
|
line_offset = -viewport_offset
|
||||||
|
|
||||||
|
for line in self._offsetable[wave_idx]:
|
||||||
|
line.set_xdata(line.xdata + line_offset * line.stride)
|
||||||
|
|
||||||
|
def _add_line_mono(
|
||||||
|
self, wave_idx: int, stride: int, dummy_data: np.ndarray
|
||||||
|
) -> CustomLine:
|
||||||
|
ys = np.zeros_like(dummy_data)
|
||||||
|
xs = calc_xs(len(ys), stride)
|
||||||
|
return self._add_xy_line_mono(wave_idx, xs, ys, stride)
|
||||||
|
|
||||||
|
def _add_vline_mono(self, wave_idx: int, stride: int) -> CustomLine:
|
||||||
|
return self._add_xy_line_mono(wave_idx, [0, 0], [-1, 1], stride)
|
||||||
|
|
||||||
|
|
||||||
|
class Renderer(RendererFrontend, MatplotlibAggRenderer):
|
||||||
|
pass
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,8 @@ from corrscope.renderer import (
|
||||||
calc_limits,
|
calc_limits,
|
||||||
calc_xs,
|
calc_xs,
|
||||||
calc_center,
|
calc_center,
|
||||||
|
AbstractMatplotlibRenderer,
|
||||||
|
RendererFrontend,
|
||||||
)
|
)
|
||||||
from corrscope.util import perr
|
from corrscope.util import perr
|
||||||
from corrscope.wave import Flatten
|
from corrscope.wave import Flatten
|
||||||
|
|
@ -412,7 +414,7 @@ def verify_res_divisor_rounding(
|
||||||
cfg.before_preview()
|
cfg.before_preview()
|
||||||
|
|
||||||
if speed_hack:
|
if speed_hack:
|
||||||
mocker.patch.object(Renderer, "_save_background")
|
mocker.patch.object(AbstractMatplotlibRenderer, "_save_background")
|
||||||
datas = []
|
datas = []
|
||||||
else:
|
else:
|
||||||
datas = [RENDER_Y_ZEROS]
|
datas = [RENDER_Y_ZEROS]
|
||||||
|
|
@ -478,3 +480,31 @@ def test_renderer_knows_stride(mocker: "pytest_mock.MockFixture", integration: b
|
||||||
corr_cfg.render, corr_cfg.layout, [data], [chan_cfg], [channel]
|
corr_cfg.render, corr_cfg.layout, [data], [chan_cfg], [channel]
|
||||||
)
|
)
|
||||||
assert renderer.render_strides == [subsampling * width_mul]
|
assert renderer.render_strides == [subsampling * width_mul]
|
||||||
|
|
||||||
|
|
||||||
|
# Multiple inheritance tests
|
||||||
|
def test_frontend_overrides_backend(mocker: "pytest_mock.MockFixture"):
|
||||||
|
"""
|
||||||
|
class Renderer inherits from (RendererFrontend, backend).
|
||||||
|
|
||||||
|
RendererFrontend.get_frame() is a wrapper around backend.get_frame()
|
||||||
|
and should override it (RendererFrontend should come first in MRO).
|
||||||
|
|
||||||
|
Make sure RendererFrontend methods overshadow backend methods.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# If RendererFrontend.get_frame() override is removed, delete this entire test.
|
||||||
|
frontend_get_frame = mocker.spy(RendererFrontend, "get_frame")
|
||||||
|
backend_get_frame = mocker.spy(AbstractMatplotlibRenderer, "get_frame")
|
||||||
|
|
||||||
|
corr_cfg = default_config()
|
||||||
|
chan_cfg = ChannelConfig("tests/sine440.wav")
|
||||||
|
channel = Channel(chan_cfg, corr_cfg, channel_idx=0)
|
||||||
|
data = channel.get_render_around(0)
|
||||||
|
|
||||||
|
renderer = Renderer(corr_cfg.render, corr_cfg.layout, [data], [chan_cfg], [channel])
|
||||||
|
renderer.update_main_lines([data])
|
||||||
|
renderer.get_frame()
|
||||||
|
|
||||||
|
assert frontend_get_frame.call_count == 1
|
||||||
|
assert backend_get_frame.call_count == 1
|
||||||
|
|
|
||||||
Ładowanie…
Reference in New Issue