Skip to content

Commit 37c76a9

Browse files
authored
issue/500: 在 infinicore Python 包中接入 ntops
1 parent a3c5f3a commit 37c76a9

File tree

4 files changed

+60
-2
lines changed

4 files changed

+60
-2
lines changed

python/infinicore/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
short,
2525
uint8,
2626
)
27+
from infinicore.ntops import use_ntops
2728
from infinicore.ops.matmul import matmul
2829
from infinicore.ops.rearrange import rearrange
2930
from infinicore.tensor import (
@@ -62,6 +63,8 @@
6263
"long",
6364
"short",
6465
"uint8",
66+
# `ntops` integration.
67+
"use_ntops",
6568
# Operations.
6669
"matmul",
6770
"rearrange",

python/infinicore/ntops.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import sys
2+
3+
import infinicore
4+
5+
6+
def use_ntops():
7+
import ntops
8+
9+
return _TemporaryAttributes(
10+
(("ntops.torch.torch", infinicore),)
11+
+ tuple(
12+
(f"infinicore.{op_name}", getattr(ntops.torch, op_name))
13+
for op_name in ntops.torch.__all__
14+
)
15+
)
16+
17+
18+
class _TemporaryAttributes:
19+
def __init__(self, attribute_mappings):
20+
self._attribute_mappings = attribute_mappings
21+
22+
self._original_values = {}
23+
24+
def __enter__(self):
25+
for attr_path, new_value in self._attribute_mappings:
26+
parent, attr_name = self._resolve_path(attr_path)
27+
28+
try:
29+
self._original_values[attr_path] = getattr(parent, attr_name)
30+
except AttributeError:
31+
pass
32+
33+
setattr(parent, attr_name, new_value)
34+
35+
return self
36+
37+
def __exit__(self, exc_type, exc_value, traceback):
38+
for attr_path, _ in self._attribute_mappings:
39+
parent, attr_name = self._resolve_path(attr_path)
40+
41+
if attr_path in self._original_values:
42+
setattr(parent, attr_name, self._original_values[attr_path])
43+
else:
44+
delattr(parent, attr_name)
45+
46+
@staticmethod
47+
def _resolve_path(path):
48+
*parent_parts, attr_name = path.split(".")
49+
50+
curr = sys.modules[parent_parts[0]]
51+
52+
for part in parent_parts[1:]:
53+
curr = getattr(curr, part)
54+
55+
return curr, attr_name

python/infinicore/tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def ndim(self):
3232
return self._underlying.ndim
3333

3434
def data_ptr(self):
35-
return self._underlying.data_ptr
35+
return self._underlying.data_ptr()
3636

3737
def size(self, dim=None):
3838
if dim is None:

src/infinicore/pybind11/tensor.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ inline void bind(py::module &m) {
1717
.def_property_readonly("dtype", [](const Tensor &tensor) { return tensor->dtype(); })
1818
.def_property_readonly("device", [](const Tensor &tensor) { return tensor->device(); })
1919

20-
.def("data_ptr", [](const Tensor &tensor) { return reinterpret_cast<uintptr_t>(tensor->data()); })
20+
.def("data_ptr", [](const Tensor &tensor) { return reinterpret_cast<std::uintptr_t>(tensor->data()); })
2121
.def("size", [](const Tensor &tensor, std::size_t dim) { return tensor->size(dim); })
2222
.def("stride", [](const Tensor &tensor, std::size_t dim) { return tensor->stride(dim); })
2323
.def("numel", [](const Tensor &tensor) { return tensor->numel(); })

0 commit comments

Comments
 (0)