kopia lustrzana https://github.com/corrscope/corrscope
Merge branch 'render-refactor'
commit
7ce15a7f89
|
@ -0,0 +1,87 @@
|
|||
from typing import Optional, TypeVar, Callable, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ovgenpy.config import register_config
|
||||
from ovgenpy.util import ceildiv
|
||||
|
||||
|
||||
@register_config(always_dump='orientation')
|
||||
class LayoutConfig:
|
||||
orientation: str = 'h'
|
||||
nrows: Optional[int] = None
|
||||
ncols: Optional[int] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.nrows:
|
||||
self.nrows = None
|
||||
if not self.ncols:
|
||||
self.ncols = None
|
||||
|
||||
if self.nrows and self.ncols:
|
||||
raise ValueError('cannot manually assign both nrows and ncols')
|
||||
|
||||
if not self.nrows and not self.ncols:
|
||||
self.ncols = 1
|
||||
|
||||
|
||||
Region = TypeVar('Region')
|
||||
RegionFactory = Callable[[int, int], Region] # f(row, column) -> Region
|
||||
|
||||
|
||||
class RendererLayout:
|
||||
VALID_ORIENTATIONS = ['h', 'v']
|
||||
|
||||
def __init__(self, cfg: LayoutConfig, nplots: int):
|
||||
self.cfg = cfg
|
||||
self.nplots = nplots
|
||||
|
||||
# Setup layout
|
||||
self.nrows, self.ncols = self._calc_layout()
|
||||
|
||||
self.orientation = cfg.orientation
|
||||
if self.orientation not in self.VALID_ORIENTATIONS:
|
||||
raise ValueError(f'Invalid orientation {self.orientation} not in '
|
||||
f'{self.VALID_ORIENTATIONS}')
|
||||
|
||||
def _calc_layout(self):
|
||||
"""
|
||||
Inputs: self.cfg, self.waves
|
||||
:return: (nrows, ncols)
|
||||
"""
|
||||
cfg = self.cfg
|
||||
|
||||
if cfg.nrows:
|
||||
nrows = cfg.nrows
|
||||
if nrows is None:
|
||||
raise ValueError('invalid cfg: rows_first is True and nrows is None')
|
||||
ncols = ceildiv(self.nplots, nrows)
|
||||
else:
|
||||
ncols = cfg.ncols
|
||||
if ncols is None:
|
||||
raise ValueError('invalid cfg: rows_first is False and ncols is None')
|
||||
nrows = ceildiv(self.nplots, ncols)
|
||||
|
||||
return nrows, ncols
|
||||
|
||||
def arrange(self, region_factory: RegionFactory) -> List[Region]:
|
||||
""" Generates an array of regions.
|
||||
|
||||
index, row, column are fed into region_factory in a row-major order [row][col].
|
||||
The results are possibly reshaped into column-major order [col][row].
|
||||
"""
|
||||
nspaces = self.nrows * self.ncols
|
||||
inds = np.arange(nspaces)
|
||||
rows, cols = np.unravel_index(inds, (self.nrows, self.ncols))
|
||||
|
||||
row_col = list(zip(rows, cols))
|
||||
regions = np.empty(len(row_col), dtype=object) # type: np.ndarray[Region]
|
||||
regions[:] = [region_factory(*rc) for rc in row_col]
|
||||
|
||||
regions2d = regions.reshape((self.nrows, self.ncols)) # type: np.ndarray[Region]
|
||||
|
||||
# if column major:
|
||||
if self.orientation == 'v':
|
||||
regions2d = regions2d.T
|
||||
|
||||
return regions2d.flatten()[:self.nplots].tolist()
|
|
@ -8,7 +8,8 @@ from typing import Optional, List, Union, TYPE_CHECKING
|
|||
from ovgenpy import outputs as outputs_
|
||||
from ovgenpy.channel import Channel, ChannelConfig
|
||||
from ovgenpy.config import register_config, register_enum, Ignored
|
||||
from ovgenpy.renderer import MatplotlibRenderer, RendererConfig, LayoutConfig
|
||||
from ovgenpy.renderer import MatplotlibRenderer, RendererConfig
|
||||
from ovgenpy.layout import LayoutConfig
|
||||
from ovgenpy.triggers import ITriggerConfig, CorrelationTriggerConfig, PerFrameCache
|
||||
from ovgenpy.util import pushd, coalesce
|
||||
from ovgenpy.utils import keyword_dataclasses as dc
|
||||
|
@ -136,8 +137,8 @@ class Ovgen:
|
|||
yield
|
||||
|
||||
def _load_renderer(self):
|
||||
renderer = MatplotlibRenderer(self.cfg.render, self.cfg.layout, self.nchan)
|
||||
renderer.set_colors(self.cfg.channels)
|
||||
renderer = MatplotlibRenderer(self.cfg.render, self.cfg.layout, self.nchan,
|
||||
self.cfg.channels)
|
||||
return renderer
|
||||
|
||||
def play(self):
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
from typing import Optional, List, TYPE_CHECKING, TypeVar, Callable, Any
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, List, TYPE_CHECKING, Any
|
||||
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
|
||||
from ovgenpy.config import register_config
|
||||
from ovgenpy.layout import RendererLayout, LayoutConfig
|
||||
from ovgenpy.outputs import RGB_DEPTH
|
||||
from ovgenpy.util import ceildiv, coalesce
|
||||
from ovgenpy.util import coalesce
|
||||
|
||||
matplotlib.use('agg')
|
||||
from matplotlib import pyplot as plt
|
||||
|
@ -38,7 +40,31 @@ class RendererConfig:
|
|||
create_window: bool = False
|
||||
|
||||
|
||||
class MatplotlibRenderer:
|
||||
class Renderer(ABC):
|
||||
def __init__(self, cfg: RendererConfig, lcfg: 'LayoutConfig', nplots: int,
|
||||
channel_cfgs: Optional[List['ChannelConfig']]):
|
||||
self.cfg = cfg
|
||||
self.nplots = nplots
|
||||
self.layout = RendererLayout(lcfg, nplots)
|
||||
|
||||
# Load line colors.
|
||||
if channel_cfgs is not None:
|
||||
if len(channel_cfgs) != self.nplots:
|
||||
raise ValueError(
|
||||
f"cannot assign {len(channel_cfgs)} colors to {self.nplots} plots"
|
||||
)
|
||||
self._line_colors = [cfg.line_color for cfg in channel_cfgs]
|
||||
else:
|
||||
self._line_colors = [None] * self.nplots
|
||||
|
||||
@abstractmethod
|
||||
def render_frame(self, datas: List[np.ndarray]) -> None: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_frame(self) -> bytes: ...
|
||||
|
||||
|
||||
class MatplotlibRenderer(Renderer):
|
||||
"""
|
||||
Renderer backend which takes data and produces images.
|
||||
Does not touch Wave or Channel.
|
||||
|
@ -61,18 +87,14 @@ class MatplotlibRenderer:
|
|||
|
||||
DPI = 96
|
||||
|
||||
def __init__(self, cfg: RendererConfig, lcfg: 'LayoutConfig', nplots: int):
|
||||
self.cfg = cfg
|
||||
self.nplots = nplots
|
||||
self.layout = RendererLayout(lcfg, nplots)
|
||||
def __init__(self, *args, **kwargs):
|
||||
Renderer.__init__(self, *args, **kwargs)
|
||||
|
||||
# Flat array of nrows*ncols elements, ordered by cfg.rows_first.
|
||||
self._fig: 'Figure' = None
|
||||
self._axes: List['Axes'] = None # set by set_layout()
|
||||
self._lines: List['Line2D'] = None # set by render_frame() first call
|
||||
|
||||
self._line_colors: List = [None] * nplots
|
||||
|
||||
self._set_layout() # mutates self
|
||||
|
||||
def _set_layout(self) -> None:
|
||||
|
@ -114,18 +136,6 @@ class MatplotlibRenderer:
|
|||
if self.cfg.create_window:
|
||||
plt.show(block=False)
|
||||
|
||||
def set_colors(self, channel_cfgs: List['ChannelConfig']):
|
||||
if len(channel_cfgs) != self.nplots:
|
||||
raise ValueError(
|
||||
f"cannot assign {len(channel_cfgs)} colors to {self.nplots} plots"
|
||||
)
|
||||
|
||||
if self._lines is not None:
|
||||
raise ValueError(
|
||||
f'cannot set line colors after calling render_frame()'
|
||||
)
|
||||
self._line_colors = [cfg.line_color for cfg in channel_cfgs]
|
||||
|
||||
def render_frame(self, datas: List[np.ndarray]) -> None:
|
||||
ndata = len(datas)
|
||||
if self.nplots != ndata:
|
||||
|
@ -179,81 +189,3 @@ class MatplotlibRenderer:
|
|||
|
||||
return buffer_rgb
|
||||
|
||||
|
||||
@register_config(always_dump='orientation')
|
||||
class LayoutConfig:
|
||||
orientation: str = 'h'
|
||||
nrows: Optional[int] = None
|
||||
ncols: Optional[int] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if not self.nrows:
|
||||
self.nrows = None
|
||||
if not self.ncols:
|
||||
self.ncols = None
|
||||
|
||||
if self.nrows and self.ncols:
|
||||
raise ValueError('cannot manually assign both nrows and ncols')
|
||||
|
||||
if not self.nrows and not self.ncols:
|
||||
self.ncols = 1
|
||||
|
||||
|
||||
Region = TypeVar('Region')
|
||||
RegionFactory = Callable[[int, int], Region] # f(row, column) -> Region
|
||||
|
||||
|
||||
class RendererLayout:
|
||||
VALID_ORIENTATIONS = ['h', 'v']
|
||||
|
||||
def __init__(self, cfg: LayoutConfig, nplots: int):
|
||||
self.cfg = cfg
|
||||
self.nplots = nplots
|
||||
|
||||
# Setup layout
|
||||
self.nrows, self.ncols = self._calc_layout()
|
||||
|
||||
self.orientation = cfg.orientation
|
||||
if self.orientation not in self.VALID_ORIENTATIONS:
|
||||
raise ValueError(f'Invalid orientation {self.orientation} not in '
|
||||
f'{self.VALID_ORIENTATIONS}')
|
||||
|
||||
def _calc_layout(self):
|
||||
"""
|
||||
Inputs: self.cfg, self.waves
|
||||
:return: (nrows, ncols)
|
||||
"""
|
||||
cfg = self.cfg
|
||||
|
||||
if cfg.nrows:
|
||||
nrows = cfg.nrows
|
||||
if nrows is None:
|
||||
raise ValueError('invalid cfg: rows_first is True and nrows is None')
|
||||
ncols = ceildiv(self.nplots, nrows)
|
||||
else:
|
||||
ncols = cfg.ncols
|
||||
if ncols is None:
|
||||
raise ValueError('invalid cfg: rows_first is False and ncols is None')
|
||||
nrows = ceildiv(self.nplots, ncols)
|
||||
|
||||
return nrows, ncols
|
||||
|
||||
def arrange(self, region_factory: RegionFactory) -> List[Region]:
|
||||
""" Generates an array of regions.
|
||||
|
||||
index, row, column are fed into region_factory in a row-major order [row][col].
|
||||
The results are possibly reshaped into column-major order [col][row].
|
||||
"""
|
||||
nspaces = self.nrows * self.ncols
|
||||
inds = np.arange(nspaces)
|
||||
rows, cols = np.unravel_index(inds, (self.nrows, self.ncols))
|
||||
|
||||
row_col = np.array([rows, cols]).T
|
||||
regions = np.array([region_factory(*rc) for rc in row_col]) # type: np.ndarray[Region]
|
||||
regions2d = regions.reshape((self.nrows, self.ncols)) # type: np.ndarray[Region]
|
||||
|
||||
# if column major:
|
||||
if self.orientation == 'v':
|
||||
regions2d = regions2d.T
|
||||
|
||||
return regions2d.flatten()[:self.nplots].tolist()
|
||||
|
|
|
@ -0,0 +1,77 @@
|
|||
import pytest
|
||||
|
||||
from ovgenpy.layout import LayoutConfig, RendererLayout
|
||||
from ovgenpy.renderer import RendererConfig, MatplotlibRenderer
|
||||
from tests.test_renderer import WIDTH, HEIGHT
|
||||
|
||||
|
||||
def test_layout_config():
|
||||
with pytest.raises(ValueError):
|
||||
LayoutConfig(nrows=1, ncols=1)
|
||||
|
||||
one_col = LayoutConfig(ncols=1)
|
||||
assert one_col
|
||||
|
||||
one_row = LayoutConfig(nrows=1)
|
||||
assert one_row
|
||||
|
||||
default = LayoutConfig()
|
||||
assert default.ncols == 1 # Should default to single-column layout
|
||||
assert default.nrows is None
|
||||
assert default.orientation == 'h'
|
||||
|
||||
|
||||
@pytest.mark.parametrize('lcfg', [
|
||||
LayoutConfig(ncols=2),
|
||||
LayoutConfig(nrows=8),
|
||||
])
|
||||
@pytest.mark.parametrize('region_type', [str, tuple, list])
|
||||
def test_hlayout(lcfg, region_type):
|
||||
nplots = 15
|
||||
layout = RendererLayout(lcfg, nplots)
|
||||
|
||||
assert layout.ncols == 2
|
||||
assert layout.nrows == 8
|
||||
|
||||
regions = layout.arrange(lambda row, col: region_type((row, col)))
|
||||
assert len(regions) == nplots
|
||||
|
||||
assert regions[0] == region_type((0, 0))
|
||||
assert regions[1] == region_type((0, 1))
|
||||
assert regions[2] == region_type((1, 0))
|
||||
m = nplots - 1
|
||||
assert regions[m] == region_type((m // 2, m % 2))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('lcfg', [
|
||||
LayoutConfig(ncols=3, orientation='v'),
|
||||
LayoutConfig(nrows=3, orientation='v'),
|
||||
])
|
||||
@pytest.mark.parametrize('region_type', [str, tuple, list])
|
||||
def test_vlayout(lcfg, region_type):
|
||||
nplots = 7
|
||||
layout = RendererLayout(lcfg, nplots)
|
||||
|
||||
assert layout.ncols == 3
|
||||
assert layout.nrows == 3
|
||||
|
||||
regions = layout.arrange(lambda row, col: region_type((row, col)))
|
||||
assert len(regions) == nplots
|
||||
|
||||
assert regions[0] == region_type((0, 0))
|
||||
assert regions[2] == region_type((2, 0))
|
||||
assert regions[3] == region_type((0, 1))
|
||||
assert regions[6] == region_type((0, 2))
|
||||
|
||||
|
||||
def test_renderer_layout():
|
||||
# 2 columns
|
||||
cfg = RendererConfig(WIDTH, HEIGHT)
|
||||
lcfg = LayoutConfig(ncols=2)
|
||||
nplots = 15
|
||||
|
||||
r = MatplotlibRenderer(cfg, lcfg, nplots, None)
|
||||
|
||||
# 2 columns, 8 rows
|
||||
assert r.layout.ncols == 2
|
||||
assert r.layout.nrows == 8
|
|
@ -19,7 +19,7 @@ NULL_CFG = FFmpegOutputConfig(None, '-f null')
|
|||
|
||||
def test_render_output():
|
||||
""" Ensure rendering to output does not raise exceptions. """
|
||||
renderer = MatplotlibRenderer(CFG.render, CFG.layout, nplots=1)
|
||||
renderer = MatplotlibRenderer(CFG.render, CFG.layout, nplots=1, channel_cfgs=None)
|
||||
out: FFmpegOutput = NULL_CFG(CFG)
|
||||
|
||||
renderer.render_frame([ALL_ZEROS])
|
||||
|
|
|
@ -4,86 +4,12 @@ from matplotlib.colors import to_rgb
|
|||
|
||||
from ovgenpy.channel import ChannelConfig
|
||||
from ovgenpy.outputs import RGB_DEPTH
|
||||
from ovgenpy.renderer import RendererConfig, MatplotlibRenderer, LayoutConfig, \
|
||||
RendererLayout
|
||||
from ovgenpy.renderer import RendererConfig, MatplotlibRenderer
|
||||
from ovgenpy.layout import LayoutConfig
|
||||
|
||||
WIDTH = 640
|
||||
HEIGHT = 360
|
||||
|
||||
|
||||
def test_config():
|
||||
with pytest.raises(ValueError):
|
||||
LayoutConfig(nrows=1, ncols=1)
|
||||
|
||||
one_col = LayoutConfig(ncols=1)
|
||||
assert one_col
|
||||
|
||||
one_row = LayoutConfig(nrows=1)
|
||||
assert one_row
|
||||
|
||||
default = LayoutConfig()
|
||||
assert default.ncols == 1 # Should default to single-column layout
|
||||
assert default.nrows is None
|
||||
assert default.orientation == 'h'
|
||||
|
||||
|
||||
@pytest.mark.parametrize('lcfg', [
|
||||
LayoutConfig(ncols=2),
|
||||
LayoutConfig(nrows=8),
|
||||
])
|
||||
def test_hlayout(lcfg):
|
||||
nplots = 15
|
||||
layout = RendererLayout(lcfg, nplots)
|
||||
|
||||
assert layout.ncols == 2
|
||||
assert layout.nrows == 8
|
||||
|
||||
# holy shit, passing tuples into a numpy array breaks things spectacularly, and it's
|
||||
# painfully difficult to stuff tuples into 1D array.
|
||||
# http://wesmckinney.com/blog/performance-quirk-making-a-1d-object-ndarray-of-tuples/
|
||||
regions = layout.arrange(lambda row, col: str((row, col)))
|
||||
assert len(regions) == nplots
|
||||
|
||||
assert regions[0] == '(0, 0)'
|
||||
assert regions[1] == '(0, 1)'
|
||||
assert regions[2] == '(1, 0)'
|
||||
m = nplots - 1
|
||||
assert regions[m] == str((m // 2, m % 2))
|
||||
|
||||
|
||||
@pytest.mark.parametrize('lcfg', [
|
||||
LayoutConfig(ncols=3, orientation='v'),
|
||||
LayoutConfig(nrows=3, orientation='v'),
|
||||
])
|
||||
def test_vlayout(lcfg):
|
||||
nplots = 7
|
||||
layout = RendererLayout(lcfg, nplots)
|
||||
|
||||
assert layout.ncols == 3
|
||||
assert layout.nrows == 3
|
||||
|
||||
regions = layout.arrange(lambda row, col: str((row, col)))
|
||||
assert len(regions) == nplots
|
||||
|
||||
assert regions[0] == '(0, 0)'
|
||||
assert regions[2] == '(2, 0)'
|
||||
assert regions[3] == '(0, 1)'
|
||||
assert regions[6] == '(0, 2)'
|
||||
|
||||
|
||||
def test_renderer():
|
||||
# 2 columns
|
||||
cfg = RendererConfig(WIDTH, HEIGHT)
|
||||
lcfg = LayoutConfig(ncols=2)
|
||||
nplots = 15
|
||||
|
||||
r = MatplotlibRenderer(cfg, lcfg, nplots)
|
||||
|
||||
# 2 columns, 8 rows
|
||||
assert r.layout.ncols == 2
|
||||
assert r.layout.nrows == 8
|
||||
|
||||
|
||||
ALL_ZEROS = np.array([0,0])
|
||||
|
||||
all_colors = pytest.mark.parametrize('bg_str,fg_str', [
|
||||
|
@ -106,13 +32,13 @@ def test_default_colors(bg_str, fg_str):
|
|||
lcfg = LayoutConfig()
|
||||
nplots = 1
|
||||
|
||||
r = MatplotlibRenderer(cfg, lcfg, nplots)
|
||||
r = MatplotlibRenderer(cfg, lcfg, nplots, None)
|
||||
verify(r, bg_str, fg_str)
|
||||
|
||||
# Ensure default ChannelConfig(line_color=None) does not override line color
|
||||
r = MatplotlibRenderer(cfg, lcfg, nplots)
|
||||
chan = ChannelConfig(wav_path='')
|
||||
r.set_colors([chan] * nplots)
|
||||
channels = [chan] * nplots
|
||||
r = MatplotlibRenderer(cfg, lcfg, nplots, channels)
|
||||
verify(r, bg_str, fg_str)
|
||||
|
||||
|
||||
|
@ -128,9 +54,9 @@ def test_line_colors(bg_str, fg_str):
|
|||
lcfg = LayoutConfig()
|
||||
nplots = 1
|
||||
|
||||
r = MatplotlibRenderer(cfg, lcfg, nplots)
|
||||
chan = ChannelConfig(wav_path='', line_color=fg_str)
|
||||
r.set_colors([chan] * nplots)
|
||||
channels = [chan] * nplots
|
||||
r = MatplotlibRenderer(cfg, lcfg, nplots, channels)
|
||||
verify(r, bg_str, fg_str)
|
||||
|
||||
|
||||
|
@ -146,7 +72,8 @@ def verify(r: MatplotlibRenderer, bg_str, fg_str):
|
|||
assert (frame_colors[0] == bg_u8).all()
|
||||
|
||||
# Ensure foreground is present
|
||||
assert np.prod(frame_colors == fg_u8, axis=-1).any()
|
||||
assert np.prod(frame_colors == fg_u8, axis=-1).any(), \
|
||||
'incorrect foreground, it might be 136 = #888888'
|
||||
|
||||
assert (np.amax(frame_colors, axis=0) == np.maximum(bg_u8, fg_u8)).all()
|
||||
assert (np.amin(frame_colors, axis=0) == np.minimum(bg_u8, fg_u8)).all()
|
||||
|
|
Ładowanie…
Reference in New Issue