diff --git a/corrscope/config.py b/corrscope/config.py index 9b57703..29212e9 100644 --- a/corrscope/config.py +++ b/corrscope/config.py @@ -117,36 +117,55 @@ class DumpableAttrs: whitespace-separated list of fields to always dump. """ - # Private variable, to avoid clashing with subclass attributes. - __always_dump: ClassVar[FrozenSet[str]] = frozenset() - if TYPE_CHECKING: def __init__(self, *args, **kwargs): pass - def __init_subclass__(cls, kw_only: bool = False, always_dump: str = ""): + # Private variable, to avoid clashing with subclass attributes. + __always_dump: ClassVar[FrozenSet[str]] = frozenset() + __exclude: ClassVar[FrozenSet[str]] = frozenset() + + def __init_subclass__( + cls, kw_only: bool = False, always_dump: str = "", exclude: str = "" + ): _yaml_loadable(attr.dataclass(cls, kw_only=kw_only)) # Merge always_dump with superclass's __always_dump. super_always_dump = cls.__always_dump + super_exclude = cls.__exclude assert type(super_always_dump) == frozenset + assert type(super_exclude) == frozenset cls.__always_dump = super_always_dump | frozenset(always_dump.split()) - del super_always_dump, always_dump + cls.__exclude = super_exclude | frozenset(exclude.split()) + del super_always_dump, always_dump + del super_exclude, exclude + + all_fields = {f.name for f in attr.fields(cls)} dump_fields = cls.__always_dump - {"*"} # remove "*" if exists + exclude_fields = cls.__exclude + if "*" in cls.__always_dump: assert ( not dump_fields ), f"Invalid always_dump, contains * and elements {dump_fields}" + for exclude_field in exclude_fields: + assert ( + exclude_field in all_fields + ), f'Invalid exclude, contains "{exclude_field}" missing from class {cls.__name__}' + else: - all_fields = {f.name for f in attr.fields(cls)} + assert ( + not exclude_fields + ), f"Invalid exclude, always_dump does not contain *" + for dump_field in dump_fields: assert ( dump_field in all_fields - ), f'Invalid always_dump="...{dump_field}" missing from class {cls.__name__}' + ), f'Invalid always_dump, contains "{dump_field}" missing from class {cls.__name__}' # SafeRepresenter.represent_yaml_object() uses __getstate__ to dump objects. def __getstate__(self) -> Dict[str, Any]: @@ -155,6 +174,7 @@ class DumpableAttrs: always_dump = self.__always_dump dump_all = "*" in always_dump + exclude = self.__exclude state = {} cls = type(self) @@ -166,9 +186,14 @@ class DumpableAttrs: but I'd lose structure checking, converters, and __attrs_post_init__. """ - if dump_all or attr_name in always_dump: + # Dump values marked as always dumped. + if attr_name in always_dump: return True + if dump_all and attr_name not in exclude: + return True + + # Don't dump default values. if field.default == value: return False # noinspection PyTypeChecker,PyUnresolvedReferences @@ -178,6 +203,7 @@ class DumpableAttrs: ): return False + # Dump values with different or missing defaults. return True for field in attr.fields(cls): diff --git a/corrscope/renderer.py b/corrscope/renderer.py index 45b2a31..baaf630 100644 --- a/corrscope/renderer.py +++ b/corrscope/renderer.py @@ -132,7 +132,9 @@ class Font(DumpableAttrs, always_dump="*"): toString: str = None -class RendererConfig(DumpableAttrs, always_dump="*"): +class RendererConfig( + DumpableAttrs, always_dump="*", exclude="viewport_width viewport_height" +): width: int height: int line_width: float = with_units("px", default=1.5) diff --git a/tests/test_config.py b/tests/test_config.py index a624a91..e786755 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -192,6 +192,39 @@ c: 3 ) +def test_exclude_dump(): + """ + Test that the exclude="" parameter can remove fields from always_dump="*". + """ + + class Config(DumpableAttrs, always_dump="*", exclude="b"): + a: int = 1 + b: int = 2 + + s = yaml.dump(Config()) + assert ( + s + == """\ +!Config +a: 1 +""" + ) + + class Special(Config, exclude="d"): + c: int = 3 + d: int = 4 + + s = yaml.dump(Special()) + assert ( + s + == """\ +!Special +a: 1 +c: 3 +""" + ) + + # Dataclass load testing @@ -371,6 +404,16 @@ def test_always_dump_validate(): class Foo(DumpableAttrs, always_dump="bar"): foo: int + with pytest.raises(AssertionError): + + class Foo(DumpableAttrs, exclude="foo"): + foo: int + + with pytest.raises(AssertionError): + + class Foo(DumpableAttrs, always_dump="*", exclude="bar"): + foo: int + # Test properties of our ruamel.yaml instance. def test_dump_no_line_break():