-
Notifications
You must be signed in to change notification settings - Fork 528
/
Copy pathunion.py
77 lines (59 loc) · 2.17 KB
/
union.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-ignore-all-errors
import functools
from dataclasses import fields
from typing import Hashable, Set
class _UnionTag(str):
    """A ``str`` naming the active field of a ``_Union``, bound to its union class.

    Comparing a tag against a string with ``==`` or ``!=`` first asserts that
    the string is one of the field names of the owning union class, so a typo
    such as ``u.type == "as_tensr"`` fails loudly instead of quietly
    evaluating to ``False`` (or ``True`` for ``!=``).
    """

    # The union class this tag was created for; assigned by ``create``.
    _cls: Hashable

    @staticmethod
    def create(t, cls):
        """Build a tag for field name ``t`` of union class ``cls``.

        Args:
            t: The field name to wrap.
            cls: The union dataclass whose field names are the valid tags.

        Returns:
            A new ``_UnionTag`` whose string value is ``t``.
        """
        tag = _UnionTag(t)
        # ``_cls`` is only annotated on the class, never assigned there, so a
        # freshly built tag must not carry it yet.
        assert not hasattr(tag, "_cls")
        tag._cls = cls
        return tag

    def __eq__(self, cmp) -> bool:
        """Compare by string value after validating ``cmp`` as a tag name.

        Raises:
            AssertionError: (when assertions are enabled) if ``cmp`` is not a
                ``str`` or is not a field name of ``self._cls``.
        """
        assert isinstance(cmp, str)
        other = str(cmp)
        assert other in _get_field_names(
            self._cls
        ), f"{other} is not a valid tag for {self._cls}. Available tags: {_get_field_names(self._cls)}"
        return str(self) == other

    def __ne__(self, cmp) -> bool:
        """Negation of ``__eq__``, performing the same tag validation.

        Without this override ``str.__ne__`` would be used, which skips the
        validation, so ``tag != "typo"`` would silently be ``True``.
        """
        return not self.__eq__(cmp)

    # Defining ``__eq__`` implicitly sets ``__hash__`` to None; restore plain
    # string hashing so a tag hashes like the equivalent ``str``.
    def __hash__(self):
        return hash(str(self))
@functools.lru_cache(maxsize=None)
def _get_field_names(cls) -> Set[str]:
    """Return the names of the dataclass fields declared on ``cls``.

    Results are memoized per class by ``lru_cache``. Repeated calls for the
    same class therefore return the very same ``set`` object, and callers must
    treat it as read-only. ``dataclasses.fields`` raises ``TypeError`` when
    ``cls`` is not a dataclass.
    """
    names: Set[str] = set()
    for field_info in fields(cls):
        names.add(field_info.name)
    return names
class _Union:
    """Base class for tagged-union dataclasses in which one field is active.

    Concrete unions subclass this and are declared as dataclasses whose fields
    are the alternatives. Instances are meant to be built with ``create``.
    ``create`` sets every other field to ``None`` and records the chosen field
    name in ``_type``. Reading a ``None`` field other than the active one
    raises ``AttributeError``.
    """

    # Tag naming the active field; assigned by ``create`` after construction
    # and absent on instances built by calling the class directly.
    _type: _UnionTag

    @classmethod
    def create(cls, **kwargs):
        """Build an instance of ``cls`` with exactly one field selected.

        Args:
            **kwargs: A single ``field_name=value`` pair; ``field_name``
                becomes the active tag.

        Returns:
            A new ``cls`` instance whose remaining fields are ``None``.

        Raises:
            AssertionError: (when assertions are enabled) if not exactly one
                keyword argument is given.
        """
        assert len(kwargs) == 1
        # Start from "every field is None", then overlay the chosen field.
        init_args = dict.fromkeys(f.name for f in fields(cls))
        init_args.update(kwargs)
        obj = cls(**init_args)  # type: ignore[arg-type]
        obj._type = _UnionTag.create(next(iter(kwargs)), cls)
        return obj

    def __post_init__(self):
        # Field names must not collide with the names this base class itself
        # provides (``type``, ``_type``, ``create``, ``value``).
        reserved = ("type", "_type", "create", "value")
        assert all(f.name not in reserved for f in fields(self))  # type: ignore[arg-type, misc]

    @property
    def type(self) -> str:
        """The ``_UnionTag`` naming the active field.

        Raises:
            RuntimeError: If the instance was not built through ``create``.
        """
        try:
            tag = self._type
        except AttributeError as exc:
            raise RuntimeError(
                f"Please use {type(self).__name__}.create to instantiate the union type."
            ) from exc
        return tag

    @property
    def value(self):
        """The value stored in the active field."""
        active = self.type
        return getattr(self, active)

    def __getattribute__(self, name):
        # Resolve normally first; only a ``None`` can signal an unset field.
        attr = super().__getattribute__(name)
        if attr is not None:
            return attr
        # ``self.type`` re-enters this method, so it is consulted only for
        # ``None`` values; checking it unconditionally would recurse forever.
        if name in _get_field_names(type(self)) and name != self.type:  # type: ignore[arg-type]
            raise AttributeError(f"Field {name} is not set.")
        return attr

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        active = self.type
        return f"{type(self).__name__}({active}={getattr(self, active)})"