Switch from keyword_dataclasses to attrs, fix Python 3.7

keyword_dataclasses works on 3.6 and 3.6 only. attrs works everywhere.

Validators and converters are not used yet.
pull/357/head
nyanpasu64 2018-11-25 05:55:55 -08:00
rodzic 8cb921cb14
commit a78c3712c9
9 zmienionych plików z 112 dodań i 1239 usunięć

Wyświetl plik

@ -1,13 +1,18 @@
from io import StringIO
from typing import ClassVar, TYPE_CHECKING
from ovgenpy.utils.keyword_dataclasses import dataclass, fields, Field, MISSING
# from dataclasses import dataclass, fields
import attr
from ruamel.yaml import yaml_object, YAML, Representer
if TYPE_CHECKING:
from enum import Enum
__all__ = ['yaml',
'register_config', 'kw_config', 'Alias', 'Ignored',
'register_enum', 'OvgenError']
# Setup YAML loading (yaml object).
class MyYAML(YAML):
@ -31,12 +36,8 @@ _yaml_loadable = yaml_object(yaml)
# Setup configuration load/dump infrastructure.
def register_config(cls=None, *, always_dump: str = ''):
""" Marks class as @dataclass, and enables YAML dumping (excludes default fields).
dataclasses.dataclass is compatible with yaml_object().
typing.NamedTuple is incompatible.
"""
def register_config(cls=None, *, kw_only=False, always_dump: str = ''):
""" Marks class as attrs, and enables YAML dumping (excludes default fields). """
def decorator(cls: type):
cls.__getstate__ = _ConfigMixin.__getstate__
@ -45,7 +46,7 @@ def register_config(cls=None, *, always_dump: str = ''):
# https://stackoverflow.com/a/51497219/2683842
# YAML().register_class(cls) works... on versions more recent than 2018-07-12.
return _yaml_loadable(dataclass(cls))
return _yaml_loadable(attr.dataclass(cls, kw_only=kw_only))
if cls is not None:
return decorator(cls)
@ -53,7 +54,11 @@ def register_config(cls=None, *, always_dump: str = ''):
return decorator
@dataclass()
def kw_config(*args, **kwargs):
return register_config(*args, **kwargs, kw_only=True)
@attr.dataclass()
class _ConfigMixin:
"""
Class is unused. __getstate__ and __setstate__ are assigned into other classes.
@ -73,7 +78,7 @@ class _ConfigMixin:
state = {}
cls = type(self)
for field in fields(self): # type: Field
for field in attr.fields(cls):
name = field.name
value = getattr(self, name)
@ -83,8 +88,9 @@ class _ConfigMixin:
if field.default == value:
continue
if field.default_factory is not MISSING \
and field.default_factory() == value:
# noinspection PyTypeChecker,PyUnresolvedReferences
if isinstance(field.default, attr.Factory) \
and field.default.factory() == value:
continue
state[name] = value
@ -117,7 +123,7 @@ class _ConfigMixin:
self.__dict__ = obj.__dict__
@dataclass
@attr.dataclass
class Alias:
"""
@register_config
@ -130,12 +136,6 @@ class Alias:
Ignored = object()
# Unused
# def default(value):
# """Supplies a mutable default value for a dataclass field."""
# string = repr(value)
# return field(default_factory=lambda: eval(string))
# Setup Enum load/dump infrastructure

Wyświetl plik

@ -12,7 +12,7 @@ class LayoutConfig:
nrows: Optional[int] = None
ncols: Optional[int] = None
def __post_init__(self):
def __attrs_post_init__(self):
if not self.nrows:
self.nrows = None
if not self.ncols:

Wyświetl plik

@ -6,15 +6,15 @@ from fractions import Fraction
from types import SimpleNamespace
from typing import Optional, List, Union, TYPE_CHECKING
import attr
from ovgenpy import outputs as outputs_
from ovgenpy.channel import Channel, ChannelConfig
from ovgenpy.config import register_config, register_enum, Ignored
from ovgenpy.config import kw_config, register_enum, Ignored
from ovgenpy.renderer import MatplotlibRenderer, RendererConfig
from ovgenpy.layout import LayoutConfig
from ovgenpy.triggers import ITriggerConfig, CorrelationTriggerConfig, PerFrameCache
from ovgenpy.util import pushd, coalesce
from ovgenpy.utils import keyword_dataclasses as dc
from ovgenpy.utils.keyword_dataclasses import field, InitVar
from ovgenpy.wave import Wave
if TYPE_CHECKING:
@ -32,7 +32,7 @@ class BenchmarkMode(IntEnum):
OUTPUT = 3
@register_config(always_dump='render_subfps begin_time end_time subsampling')
@kw_config(always_dump='render_subfps begin_time end_time subsampling')
class Config:
master_audio: Optional[str]
begin_time: float = 0
@ -49,7 +49,7 @@ class Config:
# trigger_subsampling and render_subsampling override subsampling.
trigger_subsampling: int = None
render_subsampling: int = None
subsampling: InitVar[int] = 1
_subsampling: int = 1
trigger_width: int = 1
render_width: int = 1
@ -67,7 +67,7 @@ class Config:
player: outputs_.IOutputConfig = outputs_.FFplayOutputConfig()
encoder: outputs_.IOutputConfig = outputs_.FFmpegOutputConfig(None)
show_internals: List[str] = field(default_factory=list)
show_internals: List[str] = attr.Factory(list)
benchmark_mode: Union[str, BenchmarkMode] = BenchmarkMode.NONE
# region Legacy Fields
@ -79,7 +79,7 @@ class Config:
def width_s(self) -> float:
return self.width_ms / 1000
def __post_init__(self, subsampling):
def __attrs_post_init__(self):
# Cast benchmark_mode to enum.
try:
if not isinstance(self.benchmark_mode, BenchmarkMode):
@ -90,6 +90,7 @@ class Config:
f'{[el.name for el in BenchmarkMode]}')
# Compute trigger_subsampling and render_subsampling.
subsampling = self._subsampling
self.trigger_subsampling = coalesce(self.trigger_subsampling, subsampling)
self.render_subsampling = coalesce(self.render_subsampling, subsampling)
@ -118,7 +119,7 @@ def default_config(**kwargs) -> Config:
layout=LayoutConfig(ncols=2),
render=RendererConfig(1280, 800),
)
return dc.replace(cfg, **kwargs)
return attr.evolve(cfg, **kwargs)
class Ovgen:
@ -189,9 +190,9 @@ class Ovgen:
extra_outputs = SimpleNamespace()
if internals:
from ovgenpy.outputs import FFplayOutputConfig
from ovgenpy.utils.keyword_dataclasses import replace
import attr
no_audio = replace(self.cfg, master_audio='')
no_audio = attr.evolve(self.cfg, master_audio='')
ovgen = self

Wyświetl plik

@ -3,7 +3,7 @@ from typing import Optional, List, TYPE_CHECKING, Any
import matplotlib
import numpy as np
from ovgenpy.utils.keyword_dataclasses import dataclass
import attr
from ovgenpy.config import register_config
from ovgenpy.layout import RendererLayout, LayoutConfig
@ -43,7 +43,7 @@ class RendererConfig:
create_window: bool = False
@dataclass
@attr.dataclass
class LineParam:
color: Any = None

Wyświetl plik

@ -5,10 +5,10 @@ from typing import TYPE_CHECKING, Type, Tuple, Optional, ClassVar
import numpy as np
from scipy import signal
from scipy.signal import windows
import attr
from ovgenpy.config import register_config, OvgenError, Alias
from ovgenpy.config import kw_config, OvgenError, Alias
from ovgenpy.util import find, obj_name
from ovgenpy.utils.keyword_dataclasses import dataclass
from ovgenpy.utils.windows import midpad, leftpad
from ovgenpy.wave import FLOAT
@ -18,7 +18,7 @@ if TYPE_CHECKING:
# Abstract classes
@dataclass
@attr.dataclass
class ITriggerConfig:
cls: ClassVar[Type['Trigger']]
@ -82,7 +82,7 @@ class Trigger(ABC):
...
@dataclass
@attr.dataclass
class PerFrameCache:
"""
The estimated period of a wave region (Wave.get_around())
@ -101,7 +101,7 @@ class PerFrameCache:
# CorrelationTrigger
@register_config(always_dump='''
@kw_config(always_dump='''
use_edge_trigger
edge_strength
responsiveness
@ -129,7 +129,7 @@ class CorrelationTriggerConfig(ITriggerConfig):
use_edge_trigger: bool = True
# endregion
def __post_init__(self):
def __attrs_post_init__(self):
self._validate_param('lag_prevention', 0, 1)
self._validate_param('responsiveness', 0, 1)
# TODO trigger_falloff >= 0
@ -420,7 +420,7 @@ class PostTrigger(Trigger, ABC):
# Local edge-finding trigger
@register_config(always_dump='strength')
@kw_config(always_dump='strength')
class LocalPostTriggerConfig(ITriggerConfig):
strength: float # Coefficient
@ -497,7 +497,7 @@ def seq_along(a: np.ndarray):
# ZeroCrossingTrigger
@register_config
@kw_config
class ZeroCrossingTriggerConfig(ITriggerConfig):
pass
@ -551,7 +551,7 @@ class ZeroCrossingTrigger(PostTrigger):
# NullTrigger
@register_config
@kw_config
class NullTriggerConfig(ITriggerConfig):
pass

Wyświetl plik

@ -1,12 +1,12 @@
from typing import Optional
import numpy as np
from ovgenpy.config import dataclass
import attr
from scipy.io import wavfile
# Internal class, not exposed via YAML
@dataclass
@attr.dataclass
class _WaveConfig:
amplification: float = 1

Wyświetl plik

@ -10,5 +10,6 @@ setup(
author_email='',
description='',
tests_require=['pytest', 'pytest-pycharm', 'hypothesis', 'delayed-assert'],
install_requires=['numpy', 'scipy', 'click', 'matplotlib', 'ruamel.yaml']
install_requires=['numpy', 'scipy', 'click', 'matplotlib', 'ruamel.yaml',
'attrs>=18.2.0']
)

Wyświetl plik

@ -1,15 +1,12 @@
# noinspection PyUnresolvedReferences
import sys
import pytest
from ruamel.yaml import yaml_object
from ovgenpy.config import register_config, yaml, Alias, Ignored
from ovgenpy.config import register_config, yaml, Alias, Ignored, kw_config
# YAML Idiosyncrasies: https://docs.saltstack.com/en/develop/topics/troubleshooting/yaml_idiosyncrasies.html
# Load/dump infrastructure testing
from ovgenpy.utils.keyword_dataclasses import fields
import attr
def test_register_config():
@ -18,7 +15,7 @@ def test_register_config():
foo: int
bar: int
s = yaml.dump(Foo(1, 2))
s = yaml.dump(Foo(foo=1, bar=2))
assert s == '''\
!Foo
foo: 1
@ -26,6 +23,17 @@ bar: 2
'''
def test_kw_config():
@kw_config
class Foo:
foo: int = 1
bar: int
obj = Foo(bar=2)
assert obj.foo == 1
assert obj.bar == 2
def test_yaml_object():
@yaml_object(yaml)
class Bar:
@ -75,6 +83,51 @@ b: b
'''
def test_dump_default_factory():
""" Ensure default factories are not dumped, unless attribute present
in `always_dump`.
Based on `attrs.Factory`. """
@register_config
class Config:
# Equivalent to attr.ib(factory=str)
# See https://www.attrs.org/en/stable/types.html
a: str = attr.Factory(str)
b: str = attr.Factory(str)
s = yaml.dump(Config('alpha'))
assert s == '''\
!Config
a: alpha
'''
@register_config(always_dump='a b')
class Config:
a: str = attr.Factory(str)
b: str = attr.Factory(str)
c: str = attr.Factory(str)
s = yaml.dump(Config())
assert s == '''\
!Config
a: ''
b: ''
'''
@register_config(always_dump='*')
class Config:
a: str = attr.Factory(str)
b: str = attr.Factory(str)
s = yaml.dump(Config())
assert s == '''\
!Config
a: ''
b: ''
'''
# Dataclass load testing
@ -82,13 +135,13 @@ def test_dump_load_aliases():
""" Ensure dumping and loading `xx=Alias('x')` works.
Ensure loading `{x=1, xx=1}` raises an error.
Does not check constructor `Config(xx=1)`."""
@register_config
@register_config(kw_only=False)
class Config:
x: int
xx = Alias('x')
# Test dumping
assert len(fields(Config)) == 1
assert len(attr.fields(Config)) == 1
cfg = Config(1)
s = yaml.dump(cfg)
assert s == '''\
@ -123,7 +176,7 @@ def test_dump_load_ignored():
xx = Ignored
# Test dumping
assert len(fields(Config)) == 0
assert len(attr.fields(Config)) == 0
cfg = Config()
s = yaml.dump(cfg)
assert s == '''\
@ -162,13 +215,13 @@ def test_load_argument_validation():
def test_load_post_init():
""" yaml.load() does not natively call __post_init__. So @register_config modifies
__setstate__ to call __post_init__. """
""" yaml.load() does not natively call __init__.
So @register_config modifies __setstate__ to call __attrs_post_init__. """
@register_config
class Foo:
foo: int
def __post_init__(self):
def __attrs_post_init__(self):
self.foo = 99
s = '''\