|
17 | 17 | """dtype class.""" |
18 | 18 | # pylint: disable=invalid-name |
19 | 19 | from enum import IntEnum |
20 | | -import numpy as np |
21 | 20 |
|
22 | 21 | from . import core |
23 | 22 |
|
@@ -58,22 +57,7 @@ class dtype(str): |
58 | 57 |
|
59 | 58 | __slots__ = ["__tvm_ffi_dtype__"] |
60 | 59 |
|
61 | | - NUMPY_DTYPE_TO_STR = { |
62 | | - np.dtype(np.bool_): "bool", |
63 | | - np.dtype(np.int8): "int8", |
64 | | - np.dtype(np.int16): "int16", |
65 | | - np.dtype(np.int32): "int32", |
66 | | - np.dtype(np.int64): "int64", |
67 | | - np.dtype(np.uint8): "uint8", |
68 | | - np.dtype(np.uint16): "uint16", |
69 | | - np.dtype(np.uint32): "uint32", |
70 | | - np.dtype(np.uint64): "uint64", |
71 | | - np.dtype(np.float16): "float16", |
72 | | - np.dtype(np.float32): "float32", |
73 | | - np.dtype(np.float64): "float64", |
74 | | - } |
75 | | - if hasattr(np, "float_"): |
76 | | - NUMPY_DTYPE_TO_STR[np.dtype(np.float_)] = "float64" |
| 60 | + NUMPY_DTYPE_TO_STR = {} |
77 | 61 |
|
78 | 62 | def __new__(cls, content): |
79 | 63 | content = str(content) |
@@ -122,6 +106,28 @@ def lanes(self): |
122 | 106 | return self.__tvm_ffi_dtype__.lanes |
123 | 107 |
|
124 | 108 |
|
| 109 | +try: |
| 110 | + # this helps to make numpy as optional |
| 111 | + # although almost in all cases we want numpy |
| 112 | + import numpy as np |
| 113 | + |
| 114 | + dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.bool_)] = "bool" |
| 115 | + dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.int8)] = "int8" |
| 116 | + dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.int16)] = "int16" |
| 117 | + dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.int32)] = "int32" |
| 118 | + dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.int64)] = "int64" |
| 119 | + dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.uint8)] = "uint8" |
| 120 | + dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.uint16)] = "uint16" |
| 121 | + dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.uint32)] = "uint32" |
| 122 | + dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.uint64)] = "uint64" |
| 123 | + dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.float16)] = "float16" |
| 124 | + dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.float32)] = "float32" |
| 125 | + dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.float64)] = "float64" |
| 126 | + if hasattr(np, "float_"): |
| 127 | + dtype.NUMPY_DTYPE_TO_STR[np.dtype(np.float_)] = "float64" |
| 128 | +except ImportError: |
| 129 | + pass |
| 130 | + |
125 | 131 | try: |
126 | 132 | import ml_dtypes |
127 | 133 |
|
|
0 commit comments