kopia lustrzana https://github.com/corrscope/corrscope
Switch from keyword_dataclasses to attrs, fix Python 3.7
keyword_dataclasses works on 3.6 and 3.6 only. attrs works everywhere. Validators and converters are not used yet.pull/357/head
rodzic
8cb921cb14
commit
a78c3712c9
|
@ -1,13 +1,18 @@
|
|||
from io import StringIO
|
||||
from typing import ClassVar, TYPE_CHECKING
|
||||
|
||||
from ovgenpy.utils.keyword_dataclasses import dataclass, fields, Field, MISSING
|
||||
# from dataclasses import dataclass, fields
|
||||
import attr
|
||||
from ruamel.yaml import yaml_object, YAML, Representer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from enum import Enum
|
||||
|
||||
|
||||
__all__ = ['yaml',
|
||||
'register_config', 'kw_config', 'Alias', 'Ignored',
|
||||
'register_enum', 'OvgenError']
|
||||
|
||||
|
||||
# Setup YAML loading (yaml object).
|
||||
|
||||
class MyYAML(YAML):
|
||||
|
@ -31,12 +36,8 @@ _yaml_loadable = yaml_object(yaml)
|
|||
|
||||
# Setup configuration load/dump infrastructure.
|
||||
|
||||
def register_config(cls=None, *, always_dump: str = ''):
|
||||
""" Marks class as @dataclass, and enables YAML dumping (excludes default fields).
|
||||
|
||||
dataclasses.dataclass is compatible with yaml_object().
|
||||
typing.NamedTuple is incompatible.
|
||||
"""
|
||||
def register_config(cls=None, *, kw_only=False, always_dump: str = ''):
|
||||
""" Marks class as attrs, and enables YAML dumping (excludes default fields). """
|
||||
|
||||
def decorator(cls: type):
|
||||
cls.__getstate__ = _ConfigMixin.__getstate__
|
||||
|
@ -45,7 +46,7 @@ def register_config(cls=None, *, always_dump: str = ''):
|
|||
|
||||
# https://stackoverflow.com/a/51497219/2683842
|
||||
# YAML().register_class(cls) works... on versions more recent than 2018-07-12.
|
||||
return _yaml_loadable(dataclass(cls))
|
||||
return _yaml_loadable(attr.dataclass(cls, kw_only=kw_only))
|
||||
|
||||
if cls is not None:
|
||||
return decorator(cls)
|
||||
|
@ -53,7 +54,11 @@ def register_config(cls=None, *, always_dump: str = ''):
|
|||
return decorator
|
||||
|
||||
|
||||
@dataclass()
|
||||
def kw_config(*args, **kwargs):
|
||||
return register_config(*args, **kwargs, kw_only=True)
|
||||
|
||||
|
||||
@attr.dataclass()
|
||||
class _ConfigMixin:
|
||||
"""
|
||||
Class is unused. __getstate__ and __setstate__ are assigned into other classes.
|
||||
|
@ -73,7 +78,7 @@ class _ConfigMixin:
|
|||
state = {}
|
||||
cls = type(self)
|
||||
|
||||
for field in fields(self): # type: Field
|
||||
for field in attr.fields(cls):
|
||||
name = field.name
|
||||
value = getattr(self, name)
|
||||
|
||||
|
@ -83,8 +88,9 @@ class _ConfigMixin:
|
|||
|
||||
if field.default == value:
|
||||
continue
|
||||
if field.default_factory is not MISSING \
|
||||
and field.default_factory() == value:
|
||||
# noinspection PyTypeChecker,PyUnresolvedReferences
|
||||
if isinstance(field.default, attr.Factory) \
|
||||
and field.default.factory() == value:
|
||||
continue
|
||||
|
||||
state[name] = value
|
||||
|
@ -117,7 +123,7 @@ class _ConfigMixin:
|
|||
self.__dict__ = obj.__dict__
|
||||
|
||||
|
||||
@dataclass
|
||||
@attr.dataclass
|
||||
class Alias:
|
||||
"""
|
||||
@register_config
|
||||
|
@ -130,12 +136,6 @@ class Alias:
|
|||
|
||||
Ignored = object()
|
||||
|
||||
# Unused
|
||||
# def default(value):
|
||||
# """Supplies a mutable default value for a dataclass field."""
|
||||
# string = repr(value)
|
||||
# return field(default_factory=lambda: eval(string))
|
||||
|
||||
|
||||
# Setup Enum load/dump infrastructure
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ class LayoutConfig:
|
|||
nrows: Optional[int] = None
|
||||
ncols: Optional[int] = None
|
||||
|
||||
def __post_init__(self):
|
||||
def __attrs_post_init__(self):
|
||||
if not self.nrows:
|
||||
self.nrows = None
|
||||
if not self.ncols:
|
||||
|
|
|
@ -6,15 +6,15 @@ from fractions import Fraction
|
|||
from types import SimpleNamespace
|
||||
from typing import Optional, List, Union, TYPE_CHECKING
|
||||
|
||||
import attr
|
||||
|
||||
from ovgenpy import outputs as outputs_
|
||||
from ovgenpy.channel import Channel, ChannelConfig
|
||||
from ovgenpy.config import register_config, register_enum, Ignored
|
||||
from ovgenpy.config import kw_config, register_enum, Ignored
|
||||
from ovgenpy.renderer import MatplotlibRenderer, RendererConfig
|
||||
from ovgenpy.layout import LayoutConfig
|
||||
from ovgenpy.triggers import ITriggerConfig, CorrelationTriggerConfig, PerFrameCache
|
||||
from ovgenpy.util import pushd, coalesce
|
||||
from ovgenpy.utils import keyword_dataclasses as dc
|
||||
from ovgenpy.utils.keyword_dataclasses import field, InitVar
|
||||
from ovgenpy.wave import Wave
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -32,7 +32,7 @@ class BenchmarkMode(IntEnum):
|
|||
OUTPUT = 3
|
||||
|
||||
|
||||
@register_config(always_dump='render_subfps begin_time end_time subsampling')
|
||||
@kw_config(always_dump='render_subfps begin_time end_time subsampling')
|
||||
class Config:
|
||||
master_audio: Optional[str]
|
||||
begin_time: float = 0
|
||||
|
@ -49,7 +49,7 @@ class Config:
|
|||
# trigger_subsampling and render_subsampling override subsampling.
|
||||
trigger_subsampling: int = None
|
||||
render_subsampling: int = None
|
||||
subsampling: InitVar[int] = 1
|
||||
_subsampling: int = 1
|
||||
|
||||
trigger_width: int = 1
|
||||
render_width: int = 1
|
||||
|
@ -67,7 +67,7 @@ class Config:
|
|||
player: outputs_.IOutputConfig = outputs_.FFplayOutputConfig()
|
||||
encoder: outputs_.IOutputConfig = outputs_.FFmpegOutputConfig(None)
|
||||
|
||||
show_internals: List[str] = field(default_factory=list)
|
||||
show_internals: List[str] = attr.Factory(list)
|
||||
benchmark_mode: Union[str, BenchmarkMode] = BenchmarkMode.NONE
|
||||
|
||||
# region Legacy Fields
|
||||
|
@ -79,7 +79,7 @@ class Config:
|
|||
def width_s(self) -> float:
|
||||
return self.width_ms / 1000
|
||||
|
||||
def __post_init__(self, subsampling):
|
||||
def __attrs_post_init__(self):
|
||||
# Cast benchmark_mode to enum.
|
||||
try:
|
||||
if not isinstance(self.benchmark_mode, BenchmarkMode):
|
||||
|
@ -90,6 +90,7 @@ class Config:
|
|||
f'{[el.name for el in BenchmarkMode]}')
|
||||
|
||||
# Compute trigger_subsampling and render_subsampling.
|
||||
subsampling = self._subsampling
|
||||
self.trigger_subsampling = coalesce(self.trigger_subsampling, subsampling)
|
||||
self.render_subsampling = coalesce(self.render_subsampling, subsampling)
|
||||
|
||||
|
@ -118,7 +119,7 @@ def default_config(**kwargs) -> Config:
|
|||
layout=LayoutConfig(ncols=2),
|
||||
render=RendererConfig(1280, 800),
|
||||
)
|
||||
return dc.replace(cfg, **kwargs)
|
||||
return attr.evolve(cfg, **kwargs)
|
||||
|
||||
|
||||
class Ovgen:
|
||||
|
@ -189,9 +190,9 @@ class Ovgen:
|
|||
extra_outputs = SimpleNamespace()
|
||||
if internals:
|
||||
from ovgenpy.outputs import FFplayOutputConfig
|
||||
from ovgenpy.utils.keyword_dataclasses import replace
|
||||
import attr
|
||||
|
||||
no_audio = replace(self.cfg, master_audio='')
|
||||
no_audio = attr.evolve(self.cfg, master_audio='')
|
||||
|
||||
ovgen = self
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import Optional, List, TYPE_CHECKING, Any
|
|||
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
from ovgenpy.utils.keyword_dataclasses import dataclass
|
||||
import attr
|
||||
|
||||
from ovgenpy.config import register_config
|
||||
from ovgenpy.layout import RendererLayout, LayoutConfig
|
||||
|
@ -43,7 +43,7 @@ class RendererConfig:
|
|||
create_window: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
@attr.dataclass
|
||||
class LineParam:
|
||||
color: Any = None
|
||||
|
||||
|
|
|
@ -5,10 +5,10 @@ from typing import TYPE_CHECKING, Type, Tuple, Optional, ClassVar
|
|||
import numpy as np
|
||||
from scipy import signal
|
||||
from scipy.signal import windows
|
||||
import attr
|
||||
|
||||
from ovgenpy.config import register_config, OvgenError, Alias
|
||||
from ovgenpy.config import kw_config, OvgenError, Alias
|
||||
from ovgenpy.util import find, obj_name
|
||||
from ovgenpy.utils.keyword_dataclasses import dataclass
|
||||
from ovgenpy.utils.windows import midpad, leftpad
|
||||
from ovgenpy.wave import FLOAT
|
||||
|
||||
|
@ -18,7 +18,7 @@ if TYPE_CHECKING:
|
|||
|
||||
# Abstract classes
|
||||
|
||||
@dataclass
|
||||
@attr.dataclass
|
||||
class ITriggerConfig:
|
||||
cls: ClassVar[Type['Trigger']]
|
||||
|
||||
|
@ -82,7 +82,7 @@ class Trigger(ABC):
|
|||
...
|
||||
|
||||
|
||||
@dataclass
|
||||
@attr.dataclass
|
||||
class PerFrameCache:
|
||||
"""
|
||||
The estimated period of a wave region (Wave.get_around())
|
||||
|
@ -101,7 +101,7 @@ class PerFrameCache:
|
|||
|
||||
# CorrelationTrigger
|
||||
|
||||
@register_config(always_dump='''
|
||||
@kw_config(always_dump='''
|
||||
use_edge_trigger
|
||||
edge_strength
|
||||
responsiveness
|
||||
|
@ -129,7 +129,7 @@ class CorrelationTriggerConfig(ITriggerConfig):
|
|||
use_edge_trigger: bool = True
|
||||
# endregion
|
||||
|
||||
def __post_init__(self):
|
||||
def __attrs_post_init__(self):
|
||||
self._validate_param('lag_prevention', 0, 1)
|
||||
self._validate_param('responsiveness', 0, 1)
|
||||
# TODO trigger_falloff >= 0
|
||||
|
@ -420,7 +420,7 @@ class PostTrigger(Trigger, ABC):
|
|||
|
||||
# Local edge-finding trigger
|
||||
|
||||
@register_config(always_dump='strength')
|
||||
@kw_config(always_dump='strength')
|
||||
class LocalPostTriggerConfig(ITriggerConfig):
|
||||
strength: float # Coefficient
|
||||
|
||||
|
@ -497,7 +497,7 @@ def seq_along(a: np.ndarray):
|
|||
|
||||
# ZeroCrossingTrigger
|
||||
|
||||
@register_config
|
||||
@kw_config
|
||||
class ZeroCrossingTriggerConfig(ITriggerConfig):
|
||||
pass
|
||||
|
||||
|
@ -551,7 +551,7 @@ class ZeroCrossingTrigger(PostTrigger):
|
|||
|
||||
# NullTrigger
|
||||
|
||||
@register_config
|
||||
@kw_config
|
||||
class NullTriggerConfig(ITriggerConfig):
|
||||
pass
|
||||
|
||||
|
|
Plik diff jest za duży
Load Diff
|
@ -1,12 +1,12 @@
|
|||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
from ovgenpy.config import dataclass
|
||||
import attr
|
||||
from scipy.io import wavfile
|
||||
|
||||
|
||||
# Internal class, not exposed via YAML
|
||||
@dataclass
|
||||
@attr.dataclass
|
||||
class _WaveConfig:
|
||||
amplification: float = 1
|
||||
|
||||
|
|
3
setup.py
3
setup.py
|
@ -10,5 +10,6 @@ setup(
|
|||
author_email='',
|
||||
description='',
|
||||
tests_require=['pytest', 'pytest-pycharm', 'hypothesis', 'delayed-assert'],
|
||||
install_requires=['numpy', 'scipy', 'click', 'matplotlib', 'ruamel.yaml']
|
||||
install_requires=['numpy', 'scipy', 'click', 'matplotlib', 'ruamel.yaml',
|
||||
'attrs>=18.2.0']
|
||||
)
|
||||
|
|
|
@ -1,15 +1,12 @@
|
|||
# noinspection PyUnresolvedReferences
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from ruamel.yaml import yaml_object
|
||||
|
||||
from ovgenpy.config import register_config, yaml, Alias, Ignored
|
||||
from ovgenpy.config import register_config, yaml, Alias, Ignored, kw_config
|
||||
|
||||
# YAML Idiosyncrasies: https://docs.saltstack.com/en/develop/topics/troubleshooting/yaml_idiosyncrasies.html
|
||||
|
||||
# Load/dump infrastructure testing
|
||||
from ovgenpy.utils.keyword_dataclasses import fields
|
||||
import attr
|
||||
|
||||
|
||||
def test_register_config():
|
||||
|
@ -18,7 +15,7 @@ def test_register_config():
|
|||
foo: int
|
||||
bar: int
|
||||
|
||||
s = yaml.dump(Foo(1, 2))
|
||||
s = yaml.dump(Foo(foo=1, bar=2))
|
||||
assert s == '''\
|
||||
!Foo
|
||||
foo: 1
|
||||
|
@ -26,6 +23,17 @@ bar: 2
|
|||
'''
|
||||
|
||||
|
||||
def test_kw_config():
|
||||
@kw_config
|
||||
class Foo:
|
||||
foo: int = 1
|
||||
bar: int
|
||||
|
||||
obj = Foo(bar=2)
|
||||
assert obj.foo == 1
|
||||
assert obj.bar == 2
|
||||
|
||||
|
||||
def test_yaml_object():
|
||||
@yaml_object(yaml)
|
||||
class Bar:
|
||||
|
@ -75,6 +83,51 @@ b: b
|
|||
'''
|
||||
|
||||
|
||||
def test_dump_default_factory():
|
||||
""" Ensure default factories are not dumped, unless attribute present
|
||||
in `always_dump`.
|
||||
|
||||
Based on `attrs.Factory`. """
|
||||
|
||||
@register_config
|
||||
class Config:
|
||||
# Equivalent to attr.ib(factory=str)
|
||||
# See https://www.attrs.org/en/stable/types.html
|
||||
a: str = attr.Factory(str)
|
||||
b: str = attr.Factory(str)
|
||||
|
||||
s = yaml.dump(Config('alpha'))
|
||||
assert s == '''\
|
||||
!Config
|
||||
a: alpha
|
||||
'''
|
||||
|
||||
@register_config(always_dump='a b')
|
||||
class Config:
|
||||
a: str = attr.Factory(str)
|
||||
b: str = attr.Factory(str)
|
||||
c: str = attr.Factory(str)
|
||||
|
||||
s = yaml.dump(Config())
|
||||
assert s == '''\
|
||||
!Config
|
||||
a: ''
|
||||
b: ''
|
||||
'''
|
||||
|
||||
@register_config(always_dump='*')
|
||||
class Config:
|
||||
a: str = attr.Factory(str)
|
||||
b: str = attr.Factory(str)
|
||||
|
||||
s = yaml.dump(Config())
|
||||
assert s == '''\
|
||||
!Config
|
||||
a: ''
|
||||
b: ''
|
||||
'''
|
||||
|
||||
|
||||
# Dataclass load testing
|
||||
|
||||
|
||||
|
@ -82,13 +135,13 @@ def test_dump_load_aliases():
|
|||
""" Ensure dumping and loading `xx=Alias('x')` works.
|
||||
Ensure loading `{x=1, xx=1}` raises an error.
|
||||
Does not check constructor `Config(xx=1)`."""
|
||||
@register_config
|
||||
@register_config(kw_only=False)
|
||||
class Config:
|
||||
x: int
|
||||
xx = Alias('x')
|
||||
|
||||
# Test dumping
|
||||
assert len(fields(Config)) == 1
|
||||
assert len(attr.fields(Config)) == 1
|
||||
cfg = Config(1)
|
||||
s = yaml.dump(cfg)
|
||||
assert s == '''\
|
||||
|
@ -123,7 +176,7 @@ def test_dump_load_ignored():
|
|||
xx = Ignored
|
||||
|
||||
# Test dumping
|
||||
assert len(fields(Config)) == 0
|
||||
assert len(attr.fields(Config)) == 0
|
||||
cfg = Config()
|
||||
s = yaml.dump(cfg)
|
||||
assert s == '''\
|
||||
|
@ -162,13 +215,13 @@ def test_load_argument_validation():
|
|||
|
||||
|
||||
def test_load_post_init():
|
||||
""" yaml.load() does not natively call __post_init__. So @register_config modifies
|
||||
__setstate__ to call __post_init__. """
|
||||
""" yaml.load() does not natively call __init__.
|
||||
So @register_config modifies __setstate__ to call __attrs_post_init__. """
|
||||
@register_config
|
||||
class Foo:
|
||||
foo: int
|
||||
|
||||
def __post_init__(self):
|
||||
def __attrs_post_init__(self):
|
||||
self.foo = 99
|
||||
|
||||
s = '''\
|
||||
|
|
Ładowanie…
Reference in New Issue