diff --git a/nnsmith/abstract/dtype.py b/nnsmith/abstract/dtype.py index 45bcb67..3972d0a 100644 --- a/nnsmith/abstract/dtype.py +++ b/nnsmith/abstract/dtype.py @@ -51,31 +51,7 @@ def is_float(self): @staticmethod def from_str(s): - return { - "f16": DType.float16, - "f32": DType.float32, - "f64": DType.float64, - "u8": DType.uint8, - "i8": DType.int8, - "i32": DType.int32, - "i64": DType.int64, - "c64": DType.complex64, - "c128": DType.complex128, - "float16": DType.float16, - "float32": DType.float32, - "float64": DType.float64, - "uint8": DType.uint8, - "uint16": DType.uint16, - "uint32": DType.uint32, - "uint64": DType.uint64, - "int8": DType.int8, - "int16": DType.int16, - "int32": DType.int32, - "int64": DType.int64, - "complex64": DType.complex64, - "complex128": DType.complex128, - "bool": DType.bool, - }[s] + return DType._FROM_STR_MAP[s] def numpy(self): return { @@ -191,6 +167,9 @@ def sizeof(self) -> int: DType.bool: 1, # Follow C/C++ convention. }[self] +DType._FROM_STR_MAP = {e.name: e for e in DType} +DType._FROM_STR_MAP.update({e.short(): e for e in DType}) + # "DTYPE_GEN*" means data types used for symbolic generation. # "DTYPE_GEN_ALL" is surely a subset of all types but it is