kopia lustrzana https://github.com/corrscope/corrscope
(Config YAML) Add support for always dumping default fields
@register_config(always_dump='field names || *')pull/357/head
rodzic
1124c46f1e
commit
7f0e1ccaba
|
@ -1,4 +1,5 @@
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
from dataclasses import dataclass, fields
|
from dataclasses import dataclass, fields
|
||||||
from ruamel.yaml import yaml_object, YAML
|
from ruamel.yaml import yaml_object, YAML
|
||||||
|
@ -22,59 +23,84 @@ class MyYAML(YAML):
|
||||||
yaml = MyYAML()
|
yaml = MyYAML()
|
||||||
|
|
||||||
|
|
||||||
class OvgenError(Exception):
|
def register_config(cls=None, *, always_dump: str = ''):
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def __getstate__(self):
|
|
||||||
""" Returns all non-default fields. """
|
|
||||||
state = {}
|
|
||||||
cls = type(self)
|
|
||||||
|
|
||||||
for field in fields(self):
|
|
||||||
name = field.name
|
|
||||||
value = getattr(self, name)
|
|
||||||
default = getattr(cls, name, object())
|
|
||||||
|
|
||||||
if value != default:
|
|
||||||
state[name] = value
|
|
||||||
|
|
||||||
return state
|
|
||||||
|
|
||||||
|
|
||||||
def __setstate__(self, state):
|
|
||||||
""" Checks that all fields match their correct types. """
|
|
||||||
self.__dict__.update(state)
|
|
||||||
for field in fields(self):
|
|
||||||
key = field.name
|
|
||||||
value = getattr(self, key)
|
|
||||||
typ = field.type
|
|
||||||
|
|
||||||
if not isinstance(value, typ):
|
|
||||||
name = type(self).__name__
|
|
||||||
raise OvgenError(f'{name}.{key} was supplied {repr(value)}, should be of type {typ.__name__}')
|
|
||||||
|
|
||||||
if hasattr(self, '__post_init__'):
|
|
||||||
self.__post_init__()
|
|
||||||
|
|
||||||
|
|
||||||
def register_config(cls):
|
|
||||||
""" Marks class as @dataclass, and enables YAML dumping (excludes default fields).
|
""" Marks class as @dataclass, and enables YAML dumping (excludes default fields).
|
||||||
|
|
||||||
dataclasses.dataclass is compatible with yaml.register_class.
|
dataclasses.dataclass is compatible with yaml.register_class.
|
||||||
typing.NamedTuple is incompatible.
|
typing.NamedTuple is incompatible.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def decorator(cls: type):
|
||||||
|
cls.__getstate__ = _ConfigMixin.__getstate__
|
||||||
|
cls.__setstate__ = _ConfigMixin.__setstate__
|
||||||
|
cls.always_dump = always_dump
|
||||||
|
|
||||||
|
# https://stackoverflow.com/a/51497219/2683842
|
||||||
|
# YAML().register_class(cls) works... on versions more recent than 2018-07-12.
|
||||||
|
return yaml_object(yaml)(
|
||||||
|
dataclass(cls)
|
||||||
|
)
|
||||||
|
|
||||||
|
if cls is not None:
|
||||||
|
return decorator(cls)
|
||||||
|
else:
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
# __init__-less non-dataclasses are also compatible with yaml.register_class.
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass()
|
||||||
|
class _ConfigMixin:
|
||||||
|
"""
|
||||||
|
Class is unused. __getstate__ and __setstate__ are assigned into other classes.
|
||||||
|
Ideally I'd use inheritance, but @yaml_object and @dataclass rely on decorators,
|
||||||
|
and I want @register_config to Just Work and not need inheritance.
|
||||||
|
"""
|
||||||
|
always_dump: ClassVar[str]
|
||||||
|
|
||||||
# SafeRepresenter.represent_yaml_object() uses __getstate__ to dump objects.
|
# SafeRepresenter.represent_yaml_object() uses __getstate__ to dump objects.
|
||||||
cls.__getstate__ = __getstate__
|
def __getstate__(self):
|
||||||
|
""" Returns all fields with non-default value, or appeear in
|
||||||
|
self.always_dump. """
|
||||||
|
|
||||||
|
always_dump = set(self.always_dump.split())
|
||||||
|
dump_all = ('*' in always_dump)
|
||||||
|
|
||||||
|
state = {}
|
||||||
|
cls = type(self)
|
||||||
|
|
||||||
|
for field in fields(self):
|
||||||
|
name = field.name
|
||||||
|
value = getattr(self, name)
|
||||||
|
|
||||||
|
if dump_all or name in always_dump:
|
||||||
|
state[name] = value
|
||||||
|
continue
|
||||||
|
|
||||||
|
default = getattr(cls, name, object())
|
||||||
|
if value != default:
|
||||||
|
state[name] = value
|
||||||
|
|
||||||
|
return state
|
||||||
|
|
||||||
# SafeConstructor.construct_yaml_object() uses __setstate__ to load objects.
|
# SafeConstructor.construct_yaml_object() uses __setstate__ to load objects.
|
||||||
cls.__setstate__ = __setstate__
|
def __setstate__(self, state):
|
||||||
|
""" Checks that all fields match their correct types. """
|
||||||
|
self.__dict__.update(state)
|
||||||
|
for field in fields(self):
|
||||||
|
key = field.name
|
||||||
|
value = getattr(self, key)
|
||||||
|
typ = field.type
|
||||||
|
|
||||||
# https://stackoverflow.com/a/51497219/2683842
|
if not isinstance(value, typ):
|
||||||
# YAML().register_class(cls) works... on versions more recent than 2018-07-12.
|
name = type(self).__name__
|
||||||
return yaml_object(yaml)(
|
raise OvgenError(f'{name}.{key} was supplied {repr(value)}, should be of type {typ.__name__}')
|
||||||
dataclass(cls)
|
|
||||||
)
|
if hasattr(self, '__post_init__'):
|
||||||
|
self.__post_init__()
|
||||||
|
|
||||||
|
|
||||||
|
class OvgenError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# __init__-less non-dataclasses are also compatible with yaml.register_class.
|
|
||||||
|
|
|
@ -34,19 +34,43 @@ def test_yaml_object():
|
||||||
assert s == '!Bar {}\n'
|
assert s == '!Bar {}\n'
|
||||||
|
|
||||||
|
|
||||||
def test_dump_exclude_defaults():
|
def test_dump_defaults():
|
||||||
@register_config
|
@register_config
|
||||||
class DefaultConfig:
|
class Config:
|
||||||
a: str = 'a'
|
a: str = 'a'
|
||||||
b: str = 'b'
|
b: str = 'b'
|
||||||
|
|
||||||
s = yaml.dump(DefaultConfig('alpha'))
|
s = yaml.dump(Config('alpha'))
|
||||||
assert 'b:' not in s
|
|
||||||
assert s == '''\
|
assert s == '''\
|
||||||
!DefaultConfig
|
!Config
|
||||||
a: alpha
|
a: alpha
|
||||||
'''
|
'''
|
||||||
|
|
||||||
|
@register_config(always_dump='a b')
|
||||||
|
class Config:
|
||||||
|
a: str = 'a'
|
||||||
|
b: str = 'b'
|
||||||
|
c: str = 'c'
|
||||||
|
|
||||||
|
s = yaml.dump(Config())
|
||||||
|
assert s == '''\
|
||||||
|
!Config
|
||||||
|
a: a
|
||||||
|
b: b
|
||||||
|
'''
|
||||||
|
|
||||||
|
@register_config(always_dump='*')
|
||||||
|
class Config:
|
||||||
|
a: str = 'a'
|
||||||
|
b: str = 'b'
|
||||||
|
|
||||||
|
s = yaml.dump(Config())
|
||||||
|
assert s == '''\
|
||||||
|
!Config
|
||||||
|
a: a
|
||||||
|
b: b
|
||||||
|
'''
|
||||||
|
|
||||||
|
|
||||||
def test_load_type_checking():
|
def test_load_type_checking():
|
||||||
@register_config
|
@register_config
|
||||||
|
|
Ładowanie…
Reference in New Issue