From 65655b2645e05e16d19c96a34322092ea123c444 Mon Sep 17 00:00:00 2001 From: nyanpasu64 Date: Tue, 24 Jul 2018 04:28:53 -0700 Subject: [PATCH] [wip] Make config classes dumpable via YAML, rename base configs --- ovgenpy/config.py | 20 ++++++++++++++++++++ ovgenpy/outputs.py | 19 ++++++++++--------- ovgenpy/ovgenpy.py | 16 +++++++++------- ovgenpy/renderer.py | 11 +++++------ ovgenpy/triggers.py | 37 ++++++++++++++++++------------------- ovgenpy/wave.py | 7 +++++-- setup.py | 2 +- tests/test_config.py | 24 ++++++++++++++++++++++++ tests/test_trigger.py | 4 ++-- 9 files changed, 94 insertions(+), 46 deletions(-) create mode 100644 ovgenpy/config.py create mode 100644 tests/test_config.py diff --git a/ovgenpy/config.py b/ovgenpy/config.py new file mode 100644 index 0000000..0d39d1d --- /dev/null +++ b/ovgenpy/config.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass +from ruamel.yaml import YAML, yaml_object + + +# typing.NamedTuple is incompatible with yaml.register_class. +# dataclasses.dataclass is compatible. +# So use the latter. + +# __init__-less classes are also compatible with yaml.register_class. + + +yaml = YAML() + + +def register_dataclass(cls): + # https://stackoverflow.com/a/51497219/2683842 + # YAML.register_class(cls) has only returned cls since 2018-07-12. + return yaml_object(yaml)( + dataclass(cls) + ) diff --git a/ovgenpy/outputs.py b/ovgenpy/outputs.py index fbddd90..ec5180c 100644 --- a/ovgenpy/outputs.py +++ b/ovgenpy/outputs.py @@ -4,7 +4,7 @@ import subprocess from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Type, List, Union -from dataclasses import dataclass +from ovgenpy.config import register_dataclass if TYPE_CHECKING: import numpy as np @@ -14,7 +14,7 @@ if TYPE_CHECKING: RGB_DEPTH = 3 -class OutputConfig: +class IOutputConfig: cls: 'Type[Output]' def __call__(self, ovgen_cfg: 'Config'): @@ -22,7 +22,7 @@ class OutputConfig: class Output(ABC): - def __init__(self, ovgen_cfg: 'Config', cfg: OutputConfig): + def __init__(self, ovgen_cfg: 'Config', cfg: IOutputConfig): self.ovgen_cfg = ovgen_cfg self.cfg = cfg @@ -33,7 +33,7 @@ class Output(ABC): # Glue logic -def register_output(config_t: Type[OutputConfig]): +def register_output(config_t: Type[IOutputConfig]): def inner(output_t: Type[Output]): config_t.cls = output_t return output_t @@ -101,8 +101,8 @@ class ProcessOutput(Output): # FFmpegOutput -@dataclass -class FFmpegOutputConfig(OutputConfig): +@register_dataclass +class FFmpegOutputConfig(IOutputConfig): path: str video_template: str = '-c:v libx264 -crf 18 -bf 2 -flags +cgop -pix_fmt yuv420p -movflags faststart' audio_template: str = '-c:a aac -b:a 384k' @@ -121,7 +121,8 @@ class FFmpegOutput(ProcessOutput): # FFplayOutput -class FFplayOutputConfig(OutputConfig): +@register_dataclass +class FFplayOutputConfig(IOutputConfig): video_template: str = '-c:v copy' audio_template: str = '-c:a copy' @@ -154,8 +155,8 @@ class FFplayOutput(ProcessOutput): # ImageOutput -@dataclass -class ImageOutputConfig: +@register_dataclass +class ImageOutputConfig(IOutputConfig): path_prefix: str diff --git a/ovgenpy/ovgenpy.py b/ovgenpy/ovgenpy.py index 0af1436..039611b 100644 --- a/ovgenpy/ovgenpy.py +++ b/ovgenpy/ovgenpy.py @@ -2,32 +2,34 @@ import time from pathlib import Path -from typing import NamedTuple, Optional, List +from typing import Optional, List import click -from ovgenpy import outputs +from ovgenpy import outputs +from ovgenpy.config import register_dataclass from ovgenpy.renderer import MatplotlibRenderer, RendererConfig -from ovgenpy.triggers import TriggerConfig, CorrelationTrigger +from ovgenpy.triggers import ITriggerConfig, CorrelationTriggerConfig from ovgenpy.wave import WaveConfig, Wave RENDER_PROFILING = True -class Config(NamedTuple): +@register_dataclass +class Config: wave_dir: str audio_path: Optional[str] fps: int time_visible_ms: int scan_ratio: float - trigger: TriggerConfig # Maybe overriden per Wave + trigger: ITriggerConfig # Maybe overriden per Wave amplification: float render: RendererConfig - outputs: List[outputs.OutputConfig] + outputs: List[outputs.IOutputConfig] create_window: bool @property @@ -54,7 +56,7 @@ def main(wave_dir: str, audio_path: Optional[str], fps: int, output: str): time_visible_ms=25, scan_ratio=1, - trigger=CorrelationTrigger.Config( + trigger=CorrelationTriggerConfig( trigger_strength=1, use_edge_trigger=False, diff --git a/ovgenpy/renderer.py b/ovgenpy/renderer.py index 86f01d0..17afbec 100644 --- a/ovgenpy/renderer.py +++ b/ovgenpy/renderer.py @@ -2,23 +2,22 @@ from typing import Optional, List, Tuple, TYPE_CHECKING import matplotlib import numpy as np -from dataclasses import dataclass + +from ovgenpy.config import register_dataclass +from ovgenpy.outputs import RGB_DEPTH +from ovgenpy.util import ceildiv matplotlib.use('agg') from matplotlib import pyplot as plt from matplotlib.backends.backend_agg import FigureCanvasAgg -from ovgenpy.outputs import RGB_DEPTH -from ovgenpy.util import ceildiv - if TYPE_CHECKING: from matplotlib.axes import Axes from matplotlib.figure import Figure from matplotlib.lines import Line2D - -@dataclass +@register_dataclass class RendererConfig: width: int height: int diff --git a/ovgenpy/triggers.py b/ovgenpy/triggers.py index ee57b8a..3bcf0e6 100644 --- a/ovgenpy/triggers.py +++ b/ovgenpy/triggers.py @@ -2,9 +2,9 @@ from abc import ABC, abstractmethod from typing import TYPE_CHECKING import numpy as np -from dataclasses import dataclass from scipy import signal +from ovgenpy.config import register_dataclass from ovgenpy.util import find from ovgenpy.wave import FLOAT @@ -27,12 +27,11 @@ class Trigger(ABC): ... -class TriggerConfig: - # NamedTuple inheritance does not work. Mark children @dataclass instead. - # https://github.com/python/typing/issues/427 +class ITriggerConfig: def __call__(self, wave: 'Wave', scan_nsamp: int): # idea: __call__ return self.cls(wave, scan_nsamp, cfg=self) # problem: cannot reference XTrigger from within XTrigger + # solution: @register_trigger(XTriggerCfg) raise NotImplementedError @@ -40,26 +39,26 @@ def lerp(x: np.ndarray, y: np.ndarray, a: float): return x * (1 - a) + y * a +@register_dataclass +class CorrelationTriggerConfig(ITriggerConfig): + # get_trigger + trigger_strength: float + use_edge_trigger: bool + + # _update_buffer + responsiveness: float + falloff_width: float + + def __call__(self, wave: 'Wave', scan_nsamp: int): + return CorrelationTrigger(wave, scan_nsamp, cfg=self) + + class CorrelationTrigger(Trigger): MIN_AMPLITUDE = 0.01 - - @dataclass - class Config(TriggerConfig): - # get_trigger - trigger_strength: float - use_edge_trigger: bool - - # _update_buffer - responsiveness: float - falloff_width: float - - def __call__(self, wave: 'Wave', scan_nsamp: int): - return CorrelationTrigger(wave, scan_nsamp, cfg=self) - # get_trigger postprocessing: self._zero_trigger ZERO_CROSSING_SCAN = 256 - def __init__(self, wave: 'Wave', scan_nsamp: int, cfg: Config): + def __init__(self, wave: 'Wave', scan_nsamp: int, cfg: CorrelationTriggerConfig): """ Correlation-based trigger which looks at a window of `scan_nsamp` samples. diff --git a/ovgenpy/wave.py b/ovgenpy/wave.py index 6dc55ef..fc5e835 100644 --- a/ovgenpy/wave.py +++ b/ovgenpy/wave.py @@ -1,13 +1,16 @@ -from typing import NamedTuple, TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, Optional import numpy as np from scipy.io import wavfile +from ovgenpy.config import register_dataclass + if TYPE_CHECKING: from ovgenpy.triggers import Trigger -class WaveConfig(NamedTuple): +@register_dataclass +class WaveConfig: amplification: float = 1 diff --git a/setup.py b/setup.py index ca41821..6a86d23 100644 --- a/setup.py +++ b/setup.py @@ -10,5 +10,5 @@ setup( author_email='', description='', install_requires=['numpy', 'scipy', 'imageio', 'click', 'matplotlib', - 'dataclasses;python_version<"3.7"'] + 'dataclasses;python_version<"3.7"', 'ruamel.yaml'] ) diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..47e0cc1 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,24 @@ +import sys + +from ruamel.yaml import yaml_object + +from ovgenpy.config import register_dataclass, yaml + + +def test_register_dataclass(): + @register_dataclass + class Foo: + foo: int + bar: int + + yaml.dump(Foo(1, 2), sys.stdout) + print() + + +def test_yaml_object(): + @yaml_object(yaml) + class Bar: + pass + + yaml.dump(Bar(), sys.stdout) + print() diff --git a/tests/test_trigger.py b/tests/test_trigger.py index c185246..9700285 100644 --- a/tests/test_trigger.py +++ b/tests/test_trigger.py @@ -4,7 +4,7 @@ from matplotlib.axes import Axes from matplotlib.figure import Figure from ovgenpy import triggers -from ovgenpy.triggers import CorrelationTrigger +from ovgenpy.triggers import CorrelationTriggerConfig from ovgenpy.wave import Wave @@ -14,7 +14,7 @@ triggers.SHOW_TRIGGER = False @pytest.fixture(scope='session', params=[False, True]) def cfg(request): use_edge_trigger = request.param - return CorrelationTrigger.Config( + return CorrelationTriggerConfig( trigger_strength=1, use_edge_trigger=use_edge_trigger,