Switch from keyword_dataclasses to attrs, fix Python 3.7

keyword_dataclasses works on 3.6 and 3.6 only. attrs works everywhere.

Validators and converters are not used yet.
pull/357/head
nyanpasu64 2018-11-25 05:55:55 -08:00
rodzic 8cb921cb14
commit a78c3712c9
9 zmienionych plików z 112 dodań i 1239 usunięć

Wyświetl plik

@ -1,13 +1,18 @@
from io import StringIO from io import StringIO
from typing import ClassVar, TYPE_CHECKING from typing import ClassVar, TYPE_CHECKING
from ovgenpy.utils.keyword_dataclasses import dataclass, fields, Field, MISSING import attr
# from dataclasses import dataclass, fields
from ruamel.yaml import yaml_object, YAML, Representer from ruamel.yaml import yaml_object, YAML, Representer
if TYPE_CHECKING: if TYPE_CHECKING:
from enum import Enum from enum import Enum
__all__ = ['yaml',
'register_config', 'kw_config', 'Alias', 'Ignored',
'register_enum', 'OvgenError']
# Setup YAML loading (yaml object). # Setup YAML loading (yaml object).
class MyYAML(YAML): class MyYAML(YAML):
@ -31,12 +36,8 @@ _yaml_loadable = yaml_object(yaml)
# Setup configuration load/dump infrastructure. # Setup configuration load/dump infrastructure.
def register_config(cls=None, *, always_dump: str = ''): def register_config(cls=None, *, kw_only=False, always_dump: str = ''):
""" Marks class as @dataclass, and enables YAML dumping (excludes default fields). """ Marks class as attrs, and enables YAML dumping (excludes default fields). """
dataclasses.dataclass is compatible with yaml_object().
typing.NamedTuple is incompatible.
"""
def decorator(cls: type): def decorator(cls: type):
cls.__getstate__ = _ConfigMixin.__getstate__ cls.__getstate__ = _ConfigMixin.__getstate__
@ -45,7 +46,7 @@ def register_config(cls=None, *, always_dump: str = ''):
# https://stackoverflow.com/a/51497219/2683842 # https://stackoverflow.com/a/51497219/2683842
# YAML().register_class(cls) works... on versions more recent than 2018-07-12. # YAML().register_class(cls) works... on versions more recent than 2018-07-12.
return _yaml_loadable(dataclass(cls)) return _yaml_loadable(attr.dataclass(cls, kw_only=kw_only))
if cls is not None: if cls is not None:
return decorator(cls) return decorator(cls)
@ -53,7 +54,11 @@ def register_config(cls=None, *, always_dump: str = ''):
return decorator return decorator
@dataclass() def kw_config(*args, **kwargs):
return register_config(*args, **kwargs, kw_only=True)
@attr.dataclass()
class _ConfigMixin: class _ConfigMixin:
""" """
Class is unused. __getstate__ and __setstate__ are assigned into other classes. Class is unused. __getstate__ and __setstate__ are assigned into other classes.
@ -73,7 +78,7 @@ class _ConfigMixin:
state = {} state = {}
cls = type(self) cls = type(self)
for field in fields(self): # type: Field for field in attr.fields(cls):
name = field.name name = field.name
value = getattr(self, name) value = getattr(self, name)
@ -83,8 +88,9 @@ class _ConfigMixin:
if field.default == value: if field.default == value:
continue continue
if field.default_factory is not MISSING \ # noinspection PyTypeChecker,PyUnresolvedReferences
and field.default_factory() == value: if isinstance(field.default, attr.Factory) \
and field.default.factory() == value:
continue continue
state[name] = value state[name] = value
@ -117,7 +123,7 @@ class _ConfigMixin:
self.__dict__ = obj.__dict__ self.__dict__ = obj.__dict__
@dataclass @attr.dataclass
class Alias: class Alias:
""" """
@register_config @register_config
@ -130,12 +136,6 @@ class Alias:
Ignored = object() Ignored = object()
# Unused
# def default(value):
# """Supplies a mutable default value for a dataclass field."""
# string = repr(value)
# return field(default_factory=lambda: eval(string))
# Setup Enum load/dump infrastructure # Setup Enum load/dump infrastructure

Wyświetl plik

@ -12,7 +12,7 @@ class LayoutConfig:
nrows: Optional[int] = None nrows: Optional[int] = None
ncols: Optional[int] = None ncols: Optional[int] = None
def __post_init__(self): def __attrs_post_init__(self):
if not self.nrows: if not self.nrows:
self.nrows = None self.nrows = None
if not self.ncols: if not self.ncols:

Wyświetl plik

@ -6,15 +6,15 @@ from fractions import Fraction
from types import SimpleNamespace from types import SimpleNamespace
from typing import Optional, List, Union, TYPE_CHECKING from typing import Optional, List, Union, TYPE_CHECKING
import attr
from ovgenpy import outputs as outputs_ from ovgenpy import outputs as outputs_
from ovgenpy.channel import Channel, ChannelConfig from ovgenpy.channel import Channel, ChannelConfig
from ovgenpy.config import register_config, register_enum, Ignored from ovgenpy.config import kw_config, register_enum, Ignored
from ovgenpy.renderer import MatplotlibRenderer, RendererConfig from ovgenpy.renderer import MatplotlibRenderer, RendererConfig
from ovgenpy.layout import LayoutConfig from ovgenpy.layout import LayoutConfig
from ovgenpy.triggers import ITriggerConfig, CorrelationTriggerConfig, PerFrameCache from ovgenpy.triggers import ITriggerConfig, CorrelationTriggerConfig, PerFrameCache
from ovgenpy.util import pushd, coalesce from ovgenpy.util import pushd, coalesce
from ovgenpy.utils import keyword_dataclasses as dc
from ovgenpy.utils.keyword_dataclasses import field, InitVar
from ovgenpy.wave import Wave from ovgenpy.wave import Wave
if TYPE_CHECKING: if TYPE_CHECKING:
@ -32,7 +32,7 @@ class BenchmarkMode(IntEnum):
OUTPUT = 3 OUTPUT = 3
@register_config(always_dump='render_subfps begin_time end_time subsampling') @kw_config(always_dump='render_subfps begin_time end_time subsampling')
class Config: class Config:
master_audio: Optional[str] master_audio: Optional[str]
begin_time: float = 0 begin_time: float = 0
@ -49,7 +49,7 @@ class Config:
# trigger_subsampling and render_subsampling override subsampling. # trigger_subsampling and render_subsampling override subsampling.
trigger_subsampling: int = None trigger_subsampling: int = None
render_subsampling: int = None render_subsampling: int = None
subsampling: InitVar[int] = 1 _subsampling: int = 1
trigger_width: int = 1 trigger_width: int = 1
render_width: int = 1 render_width: int = 1
@ -67,7 +67,7 @@ class Config:
player: outputs_.IOutputConfig = outputs_.FFplayOutputConfig() player: outputs_.IOutputConfig = outputs_.FFplayOutputConfig()
encoder: outputs_.IOutputConfig = outputs_.FFmpegOutputConfig(None) encoder: outputs_.IOutputConfig = outputs_.FFmpegOutputConfig(None)
show_internals: List[str] = field(default_factory=list) show_internals: List[str] = attr.Factory(list)
benchmark_mode: Union[str, BenchmarkMode] = BenchmarkMode.NONE benchmark_mode: Union[str, BenchmarkMode] = BenchmarkMode.NONE
# region Legacy Fields # region Legacy Fields
@ -79,7 +79,7 @@ class Config:
def width_s(self) -> float: def width_s(self) -> float:
return self.width_ms / 1000 return self.width_ms / 1000
def __post_init__(self, subsampling): def __attrs_post_init__(self):
# Cast benchmark_mode to enum. # Cast benchmark_mode to enum.
try: try:
if not isinstance(self.benchmark_mode, BenchmarkMode): if not isinstance(self.benchmark_mode, BenchmarkMode):
@ -90,6 +90,7 @@ class Config:
f'{[el.name for el in BenchmarkMode]}') f'{[el.name for el in BenchmarkMode]}')
# Compute trigger_subsampling and render_subsampling. # Compute trigger_subsampling and render_subsampling.
subsampling = self._subsampling
self.trigger_subsampling = coalesce(self.trigger_subsampling, subsampling) self.trigger_subsampling = coalesce(self.trigger_subsampling, subsampling)
self.render_subsampling = coalesce(self.render_subsampling, subsampling) self.render_subsampling = coalesce(self.render_subsampling, subsampling)
@ -118,7 +119,7 @@ def default_config(**kwargs) -> Config:
layout=LayoutConfig(ncols=2), layout=LayoutConfig(ncols=2),
render=RendererConfig(1280, 800), render=RendererConfig(1280, 800),
) )
return dc.replace(cfg, **kwargs) return attr.evolve(cfg, **kwargs)
class Ovgen: class Ovgen:
@ -189,9 +190,9 @@ class Ovgen:
extra_outputs = SimpleNamespace() extra_outputs = SimpleNamespace()
if internals: if internals:
from ovgenpy.outputs import FFplayOutputConfig from ovgenpy.outputs import FFplayOutputConfig
from ovgenpy.utils.keyword_dataclasses import replace import attr
no_audio = replace(self.cfg, master_audio='') no_audio = attr.evolve(self.cfg, master_audio='')
ovgen = self ovgen = self

Wyświetl plik

@ -3,7 +3,7 @@ from typing import Optional, List, TYPE_CHECKING, Any
import matplotlib import matplotlib
import numpy as np import numpy as np
from ovgenpy.utils.keyword_dataclasses import dataclass import attr
from ovgenpy.config import register_config from ovgenpy.config import register_config
from ovgenpy.layout import RendererLayout, LayoutConfig from ovgenpy.layout import RendererLayout, LayoutConfig
@ -43,7 +43,7 @@ class RendererConfig:
create_window: bool = False create_window: bool = False
@dataclass @attr.dataclass
class LineParam: class LineParam:
color: Any = None color: Any = None

Wyświetl plik

@ -5,10 +5,10 @@ from typing import TYPE_CHECKING, Type, Tuple, Optional, ClassVar
import numpy as np import numpy as np
from scipy import signal from scipy import signal
from scipy.signal import windows from scipy.signal import windows
import attr
from ovgenpy.config import register_config, OvgenError, Alias from ovgenpy.config import kw_config, OvgenError, Alias
from ovgenpy.util import find, obj_name from ovgenpy.util import find, obj_name
from ovgenpy.utils.keyword_dataclasses import dataclass
from ovgenpy.utils.windows import midpad, leftpad from ovgenpy.utils.windows import midpad, leftpad
from ovgenpy.wave import FLOAT from ovgenpy.wave import FLOAT
@ -18,7 +18,7 @@ if TYPE_CHECKING:
# Abstract classes # Abstract classes
@dataclass @attr.dataclass
class ITriggerConfig: class ITriggerConfig:
cls: ClassVar[Type['Trigger']] cls: ClassVar[Type['Trigger']]
@ -82,7 +82,7 @@ class Trigger(ABC):
... ...
@dataclass @attr.dataclass
class PerFrameCache: class PerFrameCache:
""" """
The estimated period of a wave region (Wave.get_around()) The estimated period of a wave region (Wave.get_around())
@ -101,7 +101,7 @@ class PerFrameCache:
# CorrelationTrigger # CorrelationTrigger
@register_config(always_dump=''' @kw_config(always_dump='''
use_edge_trigger use_edge_trigger
edge_strength edge_strength
responsiveness responsiveness
@ -129,7 +129,7 @@ class CorrelationTriggerConfig(ITriggerConfig):
use_edge_trigger: bool = True use_edge_trigger: bool = True
# endregion # endregion
def __post_init__(self): def __attrs_post_init__(self):
self._validate_param('lag_prevention', 0, 1) self._validate_param('lag_prevention', 0, 1)
self._validate_param('responsiveness', 0, 1) self._validate_param('responsiveness', 0, 1)
# TODO trigger_falloff >= 0 # TODO trigger_falloff >= 0
@ -420,7 +420,7 @@ class PostTrigger(Trigger, ABC):
# Local edge-finding trigger # Local edge-finding trigger
@register_config(always_dump='strength') @kw_config(always_dump='strength')
class LocalPostTriggerConfig(ITriggerConfig): class LocalPostTriggerConfig(ITriggerConfig):
strength: float # Coefficient strength: float # Coefficient
@ -497,7 +497,7 @@ def seq_along(a: np.ndarray):
# ZeroCrossingTrigger # ZeroCrossingTrigger
@register_config @kw_config
class ZeroCrossingTriggerConfig(ITriggerConfig): class ZeroCrossingTriggerConfig(ITriggerConfig):
pass pass
@ -551,7 +551,7 @@ class ZeroCrossingTrigger(PostTrigger):
# NullTrigger # NullTrigger
@register_config @kw_config
class NullTriggerConfig(ITriggerConfig): class NullTriggerConfig(ITriggerConfig):
pass pass

Wyświetl plik

@ -1,12 +1,12 @@
from typing import Optional from typing import Optional
import numpy as np import numpy as np
from ovgenpy.config import dataclass import attr
from scipy.io import wavfile from scipy.io import wavfile
# Internal class, not exposed via YAML # Internal class, not exposed via YAML
@dataclass @attr.dataclass
class _WaveConfig: class _WaveConfig:
amplification: float = 1 amplification: float = 1

Wyświetl plik

@ -10,5 +10,6 @@ setup(
author_email='', author_email='',
description='', description='',
tests_require=['pytest', 'pytest-pycharm', 'hypothesis', 'delayed-assert'], tests_require=['pytest', 'pytest-pycharm', 'hypothesis', 'delayed-assert'],
install_requires=['numpy', 'scipy', 'click', 'matplotlib', 'ruamel.yaml'] install_requires=['numpy', 'scipy', 'click', 'matplotlib', 'ruamel.yaml',
'attrs>=18.2.0']
) )

Wyświetl plik

@ -1,15 +1,12 @@
# noinspection PyUnresolvedReferences
import sys
import pytest import pytest
from ruamel.yaml import yaml_object from ruamel.yaml import yaml_object
from ovgenpy.config import register_config, yaml, Alias, Ignored from ovgenpy.config import register_config, yaml, Alias, Ignored, kw_config
# YAML Idiosyncrasies: https://docs.saltstack.com/en/develop/topics/troubleshooting/yaml_idiosyncrasies.html # YAML Idiosyncrasies: https://docs.saltstack.com/en/develop/topics/troubleshooting/yaml_idiosyncrasies.html
# Load/dump infrastructure testing # Load/dump infrastructure testing
from ovgenpy.utils.keyword_dataclasses import fields import attr
def test_register_config(): def test_register_config():
@ -18,7 +15,7 @@ def test_register_config():
foo: int foo: int
bar: int bar: int
s = yaml.dump(Foo(1, 2)) s = yaml.dump(Foo(foo=1, bar=2))
assert s == '''\ assert s == '''\
!Foo !Foo
foo: 1 foo: 1
@ -26,6 +23,17 @@ bar: 2
''' '''
def test_kw_config():
@kw_config
class Foo:
foo: int = 1
bar: int
obj = Foo(bar=2)
assert obj.foo == 1
assert obj.bar == 2
def test_yaml_object(): def test_yaml_object():
@yaml_object(yaml) @yaml_object(yaml)
class Bar: class Bar:
@ -75,6 +83,51 @@ b: b
''' '''
def test_dump_default_factory():
""" Ensure default factories are not dumped, unless attribute present
in `always_dump`.
Based on `attrs.Factory`. """
@register_config
class Config:
# Equivalent to attr.ib(factory=str)
# See https://www.attrs.org/en/stable/types.html
a: str = attr.Factory(str)
b: str = attr.Factory(str)
s = yaml.dump(Config('alpha'))
assert s == '''\
!Config
a: alpha
'''
@register_config(always_dump='a b')
class Config:
a: str = attr.Factory(str)
b: str = attr.Factory(str)
c: str = attr.Factory(str)
s = yaml.dump(Config())
assert s == '''\
!Config
a: ''
b: ''
'''
@register_config(always_dump='*')
class Config:
a: str = attr.Factory(str)
b: str = attr.Factory(str)
s = yaml.dump(Config())
assert s == '''\
!Config
a: ''
b: ''
'''
# Dataclass load testing # Dataclass load testing
@ -82,13 +135,13 @@ def test_dump_load_aliases():
""" Ensure dumping and loading `xx=Alias('x')` works. """ Ensure dumping and loading `xx=Alias('x')` works.
Ensure loading `{x=1, xx=1}` raises an error. Ensure loading `{x=1, xx=1}` raises an error.
Does not check constructor `Config(xx=1)`.""" Does not check constructor `Config(xx=1)`."""
@register_config @register_config(kw_only=False)
class Config: class Config:
x: int x: int
xx = Alias('x') xx = Alias('x')
# Test dumping # Test dumping
assert len(fields(Config)) == 1 assert len(attr.fields(Config)) == 1
cfg = Config(1) cfg = Config(1)
s = yaml.dump(cfg) s = yaml.dump(cfg)
assert s == '''\ assert s == '''\
@ -123,7 +176,7 @@ def test_dump_load_ignored():
xx = Ignored xx = Ignored
# Test dumping # Test dumping
assert len(fields(Config)) == 0 assert len(attr.fields(Config)) == 0
cfg = Config() cfg = Config()
s = yaml.dump(cfg) s = yaml.dump(cfg)
assert s == '''\ assert s == '''\
@ -162,13 +215,13 @@ def test_load_argument_validation():
def test_load_post_init(): def test_load_post_init():
""" yaml.load() does not natively call __post_init__. So @register_config modifies """ yaml.load() does not natively call __init__.
__setstate__ to call __post_init__. """ So @register_config modifies __setstate__ to call __attrs_post_init__. """
@register_config @register_config
class Foo: class Foo:
foo: int foo: int
def __post_init__(self): def __attrs_post_init__(self):
self.foo = 99 self.foo = 99
s = '''\ s = '''\