Skip to content

Commit

Permalink
Use semi-official StrEnum backport instead of our own code (#369)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmke8 authored May 14, 2024
2 parents c99ee3c + 5a3a465 commit 394e9f6
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 61 deletions.
14 changes: 12 additions & 2 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ readme = "README.md"
python = ">=3.10,<3.13"
typing-extensions = ">=4.5.0"
numpy = "^1.23.2"
"backports.strenum" = { version = "^1.3.1", python = "<3.11" }

# wandb dependencies
pandas = { version = "^1.5.0", optional = true }
Expand Down
64 changes: 5 additions & 59 deletions ranzen/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@

from ranzen.types import Addable, SizedDataset, Subset

if sys.version_info >= (3, 11):
from enum import StrEnum
else:
from backports.strenum import StrEnum

__all__ = [
"AddDict",
"Split",
Expand Down Expand Up @@ -118,65 +123,6 @@ def str_to_enum(str_: str | E, *, enum: type[E]) -> E:
)


if sys.version_info >= (3, 11):
# will be available in python 3.11
from enum import StrEnum
else:
#
# the following is copied straight from https://github.com/python/cpython/blob/3.11/Lib/enum.py
#
# DO NOT CHANGE THIS CODE!
#
class ReprEnum(Enum):
"""
Only changes the repr(), leaving str() and format() to the mixed-in type.
"""

_S = TypeVar("_S", bound="StrEnum")

class StrEnum(str, ReprEnum):
"""
Enum where members are also (and must be) strings
"""

_value_: str

def __new__(cls: type[_S], *values: str) -> _S:
"values must already be of type `str`"
if len(values) > 3:
raise TypeError("too many arguments for str(): %r" % (values,))
if len(values) == 1:
# it must be a string
if not isinstance(values[0], str): # pyright: ignore
raise TypeError("%r is not a string" % (values[0],))
if len(values) >= 2:
# check that encoding argument is a string
if not isinstance(v1 := values[1], str): # pyright: ignore
raise TypeError("encoding must be a string, not %r" % (v1,))
if len(values) == 3:
# check that errors argument is a string
if not isinstance(v2 := values[2], str): # pyright: ignore
raise TypeError("errors must be a string, not %r" % (v2))
value = str(*values)
member = str.__new__(cls, value)
member._value_ = value
return member

def __str__(self) -> str:
return str.__str__(self)

def _generate_next_value_( # type: ignore
name: str,
start: int,
count: int,
last_values: list[Any],
) -> str:
"""
Return the lower-cased version of the member name.
"""
return name.lower()


_KT = TypeVar("_KT")
_VT = TypeVar("_VT", bound=Addable)
_VT2 = TypeVar("_VT2", bound=Addable)
Expand Down

0 comments on commit 394e9f6

Please sign in to comment.