From 72b363ab6a4f38f3f0a725f657b1de0660caa9fb Mon Sep 17 00:00:00 2001 From: Jay Date: Wed, 8 Oct 2025 15:38:03 -0400 Subject: [PATCH 1/2] fix: add missing integer types to dtype from_str There's no reason why `i16` should be missing, especially since `i8`, `i32`, `i64` exist. Also added `u16`, `u32`, `u64` for similar reasons; there's already `u8`, `uint8`, `uint16`, `uint32`, `uint64`. --- nnsmith/abstract/dtype.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/nnsmith/abstract/dtype.py b/nnsmith/abstract/dtype.py index 45bcb67..8bc3c40 100644 --- a/nnsmith/abstract/dtype.py +++ b/nnsmith/abstract/dtype.py @@ -56,7 +56,11 @@ def from_str(s): "f32": DType.float32, "f64": DType.float64, "u8": DType.uint8, + "u16": DType.uint16, + "u32": DType.uint32, + "u64": DType.uint64, "i8": DType.int8, + "i16": DType.int16, "i32": DType.int32, "i64": DType.int64, "c64": DType.complex64, From 65bd5a2dd70db4442cf4edc2b42ee242a35952e2 Mon Sep 17 00:00:00 2001 From: Jay Date: Wed, 8 Oct 2025 16:02:55 -0400 Subject: [PATCH 2/2] refactor: implement from_str method using generated dtype lookup map Replaced the from_str method's dictionary with a lookup map for better maintainability --- nnsmith/abstract/dtype.py | 33 ++++----------------------------- 1 file changed, 4 insertions(+), 29 deletions(-) diff --git a/nnsmith/abstract/dtype.py b/nnsmith/abstract/dtype.py index 8bc3c40..3972d0a 100644 --- a/nnsmith/abstract/dtype.py +++ b/nnsmith/abstract/dtype.py @@ -51,35 +51,7 @@ def is_float(self): @staticmethod def from_str(s): - return { - "f16": DType.float16, - "f32": DType.float32, - "f64": DType.float64, - "u8": DType.uint8, - "u16": DType.uint16, - "u32": DType.uint32, - "u64": DType.uint64, - "i8": DType.int8, - "i16": DType.int16, - "i32": DType.int32, - "i64": DType.int64, - "c64": DType.complex64, - "c128": DType.complex128, - "float16": DType.float16, - "float32": DType.float32, - "float64": DType.float64, - "uint8": DType.uint8, - "uint16": DType.uint16, - "uint32": DType.uint32, - "uint64": DType.uint64, - "int8": DType.int8, - "int16": DType.int16, - "int32": DType.int32, - "int64": DType.int64, - "complex64": DType.complex64, - "complex128": DType.complex128, - "bool": DType.bool, - }[s] + return DType._FROM_STR_MAP[s] def numpy(self): return { @@ -195,6 +167,9 @@ def sizeof(self) -> int: DType.bool: 1, # Follow C/C++ convention. }[self] +DType._FROM_STR_MAP = {e.name: e for e in DType} +DType._FROM_STR_MAP.update({e.short(): e for e in DType}) + # "DTYPE_GEN*" means data types used for symbolic generation. # "DTYPE_GEN_ALL" is surely a subset of all types but it is