kopia lustrzana https://github.com/corrscope/corrscope
[wip] Make config classes dumpable via YAML, rename base configs
rodzic
e1dfdd13da
commit
65655b2645
|
@ -0,0 +1,20 @@
|
|||
from dataclasses import dataclass
|
||||
from ruamel.yaml import YAML, yaml_object
|
||||
|
||||
|
||||
# typing.NamedTuple is incompatible with yaml.register_class.
|
||||
# dataclasses.dataclass is compatible.
|
||||
# So use the latter.
|
||||
|
||||
# __init__-less classes are also compatible with yaml.register_class.
|
||||
|
||||
|
||||
yaml = YAML()
|
||||
|
||||
|
||||
def register_dataclass(cls):
|
||||
# https://stackoverflow.com/a/51497219/2683842
|
||||
# YAML.register_class(cls) has only returned cls since 2018-07-12.
|
||||
return yaml_object(yaml)(
|
||||
dataclass(cls)
|
||||
)
|
|
@ -4,7 +4,7 @@ import subprocess
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Type, List, Union
|
||||
|
||||
from dataclasses import dataclass
|
||||
from ovgenpy.config import register_dataclass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
|
@ -14,7 +14,7 @@ if TYPE_CHECKING:
|
|||
RGB_DEPTH = 3
|
||||
|
||||
|
||||
class OutputConfig:
|
||||
class IOutputConfig:
|
||||
cls: 'Type[Output]'
|
||||
|
||||
def __call__(self, ovgen_cfg: 'Config'):
|
||||
|
@ -22,7 +22,7 @@ class OutputConfig:
|
|||
|
||||
|
||||
class Output(ABC):
|
||||
def __init__(self, ovgen_cfg: 'Config', cfg: OutputConfig):
|
||||
def __init__(self, ovgen_cfg: 'Config', cfg: IOutputConfig):
|
||||
self.ovgen_cfg = ovgen_cfg
|
||||
self.cfg = cfg
|
||||
|
||||
|
@ -33,7 +33,7 @@ class Output(ABC):
|
|||
|
||||
# Glue logic
|
||||
|
||||
def register_output(config_t: Type[OutputConfig]):
|
||||
def register_output(config_t: Type[IOutputConfig]):
|
||||
def inner(output_t: Type[Output]):
|
||||
config_t.cls = output_t
|
||||
return output_t
|
||||
|
@ -101,8 +101,8 @@ class ProcessOutput(Output):
|
|||
|
||||
|
||||
# FFmpegOutput
|
||||
@dataclass
|
||||
class FFmpegOutputConfig(OutputConfig):
|
||||
@register_dataclass
|
||||
class FFmpegOutputConfig(IOutputConfig):
|
||||
path: str
|
||||
video_template: str = '-c:v libx264 -crf 18 -bf 2 -flags +cgop -pix_fmt yuv420p -movflags faststart'
|
||||
audio_template: str = '-c:a aac -b:a 384k'
|
||||
|
@ -121,7 +121,8 @@ class FFmpegOutput(ProcessOutput):
|
|||
|
||||
|
||||
# FFplayOutput
|
||||
class FFplayOutputConfig(OutputConfig):
|
||||
@register_dataclass
|
||||
class FFplayOutputConfig(IOutputConfig):
|
||||
video_template: str = '-c:v copy'
|
||||
audio_template: str = '-c:a copy'
|
||||
|
||||
|
@ -154,8 +155,8 @@ class FFplayOutput(ProcessOutput):
|
|||
|
||||
|
||||
# ImageOutput
|
||||
@dataclass
|
||||
class ImageOutputConfig:
|
||||
@register_dataclass
|
||||
class ImageOutputConfig(IOutputConfig):
|
||||
path_prefix: str
|
||||
|
||||
|
||||
|
|
|
@ -2,32 +2,34 @@
|
|||
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import NamedTuple, Optional, List
|
||||
from typing import Optional, List
|
||||
|
||||
import click
|
||||
from ovgenpy import outputs
|
||||
|
||||
from ovgenpy import outputs
|
||||
from ovgenpy.config import register_dataclass
|
||||
from ovgenpy.renderer import MatplotlibRenderer, RendererConfig
|
||||
from ovgenpy.triggers import TriggerConfig, CorrelationTrigger
|
||||
from ovgenpy.triggers import ITriggerConfig, CorrelationTriggerConfig
|
||||
from ovgenpy.wave import WaveConfig, Wave
|
||||
|
||||
|
||||
RENDER_PROFILING = True
|
||||
|
||||
|
||||
class Config(NamedTuple):
|
||||
@register_dataclass
|
||||
class Config:
|
||||
wave_dir: str
|
||||
audio_path: Optional[str]
|
||||
fps: int
|
||||
|
||||
time_visible_ms: int
|
||||
scan_ratio: float
|
||||
trigger: TriggerConfig # Maybe overriden per Wave
|
||||
trigger: ITriggerConfig # Maybe overriden per Wave
|
||||
|
||||
amplification: float
|
||||
render: RendererConfig
|
||||
|
||||
outputs: List[outputs.OutputConfig]
|
||||
outputs: List[outputs.IOutputConfig]
|
||||
create_window: bool
|
||||
|
||||
@property
|
||||
|
@ -54,7 +56,7 @@ def main(wave_dir: str, audio_path: Optional[str], fps: int, output: str):
|
|||
|
||||
time_visible_ms=25,
|
||||
scan_ratio=1,
|
||||
trigger=CorrelationTrigger.Config(
|
||||
trigger=CorrelationTriggerConfig(
|
||||
trigger_strength=1,
|
||||
use_edge_trigger=False,
|
||||
|
||||
|
|
|
@ -2,23 +2,22 @@ from typing import Optional, List, Tuple, TYPE_CHECKING
|
|||
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ovgenpy.config import register_dataclass
|
||||
from ovgenpy.outputs import RGB_DEPTH
|
||||
from ovgenpy.util import ceildiv
|
||||
|
||||
matplotlib.use('agg')
|
||||
from matplotlib import pyplot as plt
|
||||
from matplotlib.backends.backend_agg import FigureCanvasAgg
|
||||
|
||||
from ovgenpy.outputs import RGB_DEPTH
|
||||
from ovgenpy.util import ceildiv
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from matplotlib.axes import Axes
|
||||
from matplotlib.figure import Figure
|
||||
from matplotlib.lines import Line2D
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
@register_dataclass
|
||||
class RendererConfig:
|
||||
width: int
|
||||
height: int
|
||||
|
|
|
@ -2,9 +2,9 @@ from abc import ABC, abstractmethod
|
|||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
from dataclasses import dataclass
|
||||
from scipy import signal
|
||||
|
||||
from ovgenpy.config import register_dataclass
|
||||
from ovgenpy.util import find
|
||||
from ovgenpy.wave import FLOAT
|
||||
|
||||
|
@ -27,12 +27,11 @@ class Trigger(ABC):
|
|||
...
|
||||
|
||||
|
||||
class TriggerConfig:
|
||||
# NamedTuple inheritance does not work. Mark children @dataclass instead.
|
||||
# https://github.com/python/typing/issues/427
|
||||
class ITriggerConfig:
|
||||
def __call__(self, wave: 'Wave', scan_nsamp: int):
|
||||
# idea: __call__ return self.cls(wave, scan_nsamp, cfg=self)
|
||||
# problem: cannot reference XTrigger from within XTrigger
|
||||
# solution: @register_trigger(XTriggerCfg)
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
@ -40,26 +39,26 @@ def lerp(x: np.ndarray, y: np.ndarray, a: float):
|
|||
return x * (1 - a) + y * a
|
||||
|
||||
|
||||
@register_dataclass
|
||||
class CorrelationTriggerConfig(ITriggerConfig):
|
||||
# get_trigger
|
||||
trigger_strength: float
|
||||
use_edge_trigger: bool
|
||||
|
||||
# _update_buffer
|
||||
responsiveness: float
|
||||
falloff_width: float
|
||||
|
||||
def __call__(self, wave: 'Wave', scan_nsamp: int):
|
||||
return CorrelationTrigger(wave, scan_nsamp, cfg=self)
|
||||
|
||||
|
||||
class CorrelationTrigger(Trigger):
|
||||
MIN_AMPLITUDE = 0.01
|
||||
|
||||
@dataclass
|
||||
class Config(TriggerConfig):
|
||||
# get_trigger
|
||||
trigger_strength: float
|
||||
use_edge_trigger: bool
|
||||
|
||||
# _update_buffer
|
||||
responsiveness: float
|
||||
falloff_width: float
|
||||
|
||||
def __call__(self, wave: 'Wave', scan_nsamp: int):
|
||||
return CorrelationTrigger(wave, scan_nsamp, cfg=self)
|
||||
|
||||
# get_trigger postprocessing: self._zero_trigger
|
||||
ZERO_CROSSING_SCAN = 256
|
||||
|
||||
def __init__(self, wave: 'Wave', scan_nsamp: int, cfg: Config):
|
||||
def __init__(self, wave: 'Wave', scan_nsamp: int, cfg: CorrelationTriggerConfig):
|
||||
"""
|
||||
Correlation-based trigger which looks at a window of `scan_nsamp` samples.
|
||||
|
||||
|
|
|
@ -1,13 +1,16 @@
|
|||
from typing import NamedTuple, TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import numpy as np
|
||||
from scipy.io import wavfile
|
||||
|
||||
from ovgenpy.config import register_dataclass
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ovgenpy.triggers import Trigger
|
||||
|
||||
|
||||
class WaveConfig(NamedTuple):
|
||||
@register_dataclass
|
||||
class WaveConfig:
|
||||
amplification: float = 1
|
||||
|
||||
|
||||
|
|
2
setup.py
2
setup.py
|
@ -10,5 +10,5 @@ setup(
|
|||
author_email='',
|
||||
description='',
|
||||
install_requires=['numpy', 'scipy', 'imageio', 'click', 'matplotlib',
|
||||
'dataclasses;python_version<"3.7"']
|
||||
'dataclasses;python_version<"3.7"', 'ruamel.yaml']
|
||||
)
|
||||
|
|
|
@ -0,0 +1,24 @@
|
|||
import sys
|
||||
|
||||
from ruamel.yaml import yaml_object
|
||||
|
||||
from ovgenpy.config import register_dataclass, yaml
|
||||
|
||||
|
||||
def test_register_dataclass():
|
||||
@register_dataclass
|
||||
class Foo:
|
||||
foo: int
|
||||
bar: int
|
||||
|
||||
yaml.dump(Foo(1, 2), sys.stdout)
|
||||
print()
|
||||
|
||||
|
||||
def test_yaml_object():
|
||||
@yaml_object(yaml)
|
||||
class Bar:
|
||||
pass
|
||||
|
||||
yaml.dump(Bar(), sys.stdout)
|
||||
print()
|
|
@ -4,7 +4,7 @@ from matplotlib.axes import Axes
|
|||
from matplotlib.figure import Figure
|
||||
|
||||
from ovgenpy import triggers
|
||||
from ovgenpy.triggers import CorrelationTrigger
|
||||
from ovgenpy.triggers import CorrelationTriggerConfig
|
||||
from ovgenpy.wave import Wave
|
||||
|
||||
|
||||
|
@ -14,7 +14,7 @@ triggers.SHOW_TRIGGER = False
|
|||
@pytest.fixture(scope='session', params=[False, True])
|
||||
def cfg(request):
|
||||
use_edge_trigger = request.param
|
||||
return CorrelationTrigger.Config(
|
||||
return CorrelationTriggerConfig(
|
||||
trigger_strength=1,
|
||||
use_edge_trigger=use_edge_trigger,
|
||||
|
||||
|
|
Ładowanie…
Reference in New Issue