Commit 2ff76a6

albanD authored and Chao1Han committed
Add device daemon (pytorch#131814)
Base implementation aiming towards pytorch/rfcs#64.
Details of the implementation and next steps in https://github.com/pytorch/pytorch/blob/gh/albanD/3/head/test/cpp_extensions/open_registration_extension/README.md

Pull Request resolved: pytorch#131814
Approved by: https://github.com/ezyang
1 parent c059e15 commit 2ff76a6

File tree

10 files changed (+614, -41 lines)


.flake8 (+1, -1)

@@ -57,7 +57,7 @@ per-file-ignores =
     torch/distributed/_tensor/_collective_utils.py: TOR901
     # This is a full package that happen to live within the test
     # folder, so ok to skip
-    test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py: TOR901
+    test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py: TOR901
 optional-ascii-coding = True
 exclude =
     ./.git,
test/cpp_extensions/open_registration_extension/README.md

@@ -1,4 +1,4 @@
-This folder contains a self-contained example of a PyTorch out-of-tree backend leveraging the "PrivateUse1" backend in core.
+This folder contains a self-contained example of a PyTorch out-of-tree backend leveraging the "PrivateUse1" backend from core.
 
 ## How to use
 Install as standalone with `python setup.py develop` (or install) from this folder.
@@ -8,6 +8,23 @@ You can run test via `python test/test_openreg.py`.
 For simplicity anything that can be implemented from python is done so.
 A real implementation will most likely want to call these different APIs from c++ directly.
 
-The current version send everything back to python and is missing most implementations in python. The only one available is the one used by the autograd engine to check how many workers to spawn.
+The current version sends everything back to Python and contains enough implementation to run a basic model, transfer data between host and device, and print device Tensors.
 
-Next step is to create the device daemon so we can actually provide and allocator and create memory, then start using features and re-route all missing methods to daemon as appropriate.
+The codebase is split as follows:
+- `pytorch_openreg/__init__.py` imports torch to get core state initialized, imports `._aten_impl` to register our aten op implementations with torch, imports `._C` to load our C++ extension that registers more ops, the allocator and hooks, and finally renames the PrivateUse1 backend and registers our python-side module.
+- `pytorch_openreg/_aten_impl.py` does two main things. It uses the `_register_same_name()` function to register hooks from C++ (like getDevice, getStream, etc.) that forward to our device daemon. It also defines a new `torch.Library` that registers a fallback which is called whenever a backend kernel for PrivateUse1 is needed. The fallback contains the logic to handle all kinds of native functions: computing the output metadata, allocating the output, and only calling into the device daemon to perform the computation.
+- `pytorch_openreg/_device_daemon.py` contains the Allocator (responsible for allocating memory on the device side, as int8 buffers, and recreating nice-looking Tensors on the device side so that aten ops can be used to run code there) and `run_op`, the logic that runs on the device side to perform compute (for simplicity of coverage, we re-build full-blown Tensors here and call aten ops on them). It also contains the Daemon, responsible for the device worker process and for sending data back and forth.
+- `pytorch_openreg/_meta_parser.py` mainly contains utilities to send objects over the wire from the user process to the device process. The main class there is `OpenRegTensorMeta`, which contains all the metadata sent to the device; this should be enough for the device to populate the output Tensor.
+
+## Next steps
+
+Currently, the autograd test is disabled because the getStream implementation is missing.
+The main next steps are to:
+- Split the daemon into a proper user-process driver and a device-process executor. The main goal is to better mimic which information is held on the user-process side and when we are actually communicating with the device. In particular, the current device or stream should be user-process information.
+- Add a Stream/Event system, most likely by having multiple request queues going from the driver to the device.
+- Add an RNG Generator.
+- Add pinned memory and a HostAllocator.
+
+Longer term:
+- Replace the current `open_registration_extension.cpp` test in PyTorch CI with this.
+- Build this module in the CI environment and enable device-generic tests on this device.
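The two files described in the README but not included in this diff, `pytorch_openreg/_device_daemon.py` and `pytorch_openreg/_meta_parser.py`, form the device-side half of the design. As a rough mental model only (a simplified sketch, not the actual implementation), the Daemon boils down to a request/response queue pair between the user process and a device worker process that owns the int8 buffers; the command names `malloc`, `free` and `run_op` mirror the ones used by `_aten_impl.py` below, while the `OpenRegTensorMeta`-based Tensor reconstruction is elided:

```python
import multiprocessing as mp

import torch


def _device_loop(req_queue, res_queue):
    # Device-process side: "device memory" is just a dict of int8 buffers.
    allocations = {}
    next_id = 0
    while True:
        cmd, args = req_queue.get()
        if cmd == "exit":
            res_queue.put(None)
            break
        elif cmd == "malloc":
            (nbytes,) = args
            next_id += 1
            allocations[next_id] = torch.empty(nbytes, dtype=torch.int8)
            res_queue.put(next_id)
        elif cmd == "free":
            allocations.pop(args[0], None)
            res_queue.put(None)
        elif cmd == "run_op":
            # The real _device_daemon.py rebuilds proper Tensors here from the
            # OpenRegTensorMeta metadata (shape/stride/dtype + allocation id)
            # and calls the requested aten op on them; elided in this sketch.
            res_queue.put(None)


class Daemon:
    # User-process side: every hook or kernel call becomes one
    # (command, args) message followed by a blocking wait for the reply.
    def __init__(self):
        ctx = mp.get_context("spawn")
        self._req, self._res = ctx.Queue(), ctx.Queue()
        self._proc = ctx.Process(
            target=_device_loop, args=(self._req, self._res), daemon=True
        )
        self._proc.start()

    def exec(self, cmd, *args):
        self._req.put((cmd, args))
        return self._res.get()


if __name__ == "__main__":
    daemon = Daemon()
    buf_id = daemon.exec("malloc", 64)  # reserve 64 bytes of "device" memory
    daemon.exec("free", buf_id)
    daemon.exec("exit")
```

In the actual module, `daemon.exec("run_op", ...)` receives the qualified op name plus tensor metadata produced by `prepare_for_sending`, rather than the no-op shown here.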
test/cpp_extensions/open_registration_extension/pytorch_openreg/__init__.py

@@ -1,26 +1,13 @@
 import torch
 
-
-# Global properties of our device
-NUM_DEVICES = 7
-
 # Create our python implementation dict so that the C++ module
 # can access it during its initialization
-_IMPL_REGISTRY = {}
-
-# Load the C++ Module
-import pytorch_openreg._C  # noqa: F401
-
-
-# Define all the implementations in the registry
-def register(fn):
-    _IMPL_REGISTRY[fn.__name__[1:]] = fn
-    return fn
+# Also register aten impls
+from ._aten_impl import _IMPL_REGISTRY as _IMPL_REGISTRY  # noqa: F401
 
 
-@register
-def _deviceCount():
-    return NUM_DEVICES
+# Load the C++ Module
+import pytorch_openreg._C  # noqa: F401  # usort: skip
 
 
 # Module used for our backend
@@ -31,15 +18,3 @@ class _OpenRegMod:
 # Set all the appropriate state on PyTorch
 torch.utils.rename_privateuse1_backend("openreg")
 torch._register_device_module("openreg", _OpenRegMod())
-
-_openreg_lib = torch.library.Library("_", "IMPL")  # ignore TOR901
-
-
-def _openreg_kernel_fallback(op, *args, **kwargs):
-    print("Calling ", op)
-    assert op is torch.ops.aten.empty.memory_format
-    # FIXME: this returns a cpu Tensor which is NOT ok.
-    return torch.empty(args[0])
-
-
-_openreg_lib.fallback(_openreg_kernel_fallback, dispatch_key="PrivateUse1")
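The last three lines kept in `__init__.py` are what make the backend visible to the rest of PyTorch. As a hedged illustration only (assuming the extension is built and installed as described in the README; the output in the comments is what one would expect, not captured from a run):

```python
import torch

import pytorch_openreg  # noqa: F401  # runs the registration code above

# After rename_privateuse1_backend("openreg"), the device string is accepted:
print(torch.device("openreg", 0))  # expected: device(type='openreg', index=0)

# After _register_device_module("openreg", ...), the module is reachable as:
print(torch.openreg)  # expected: the _OpenRegMod instance registered above
```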
test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py (new file)

@@ -0,0 +1,153 @@
+import logging
+
+import torch
+from torch.utils._pytree import tree_any
+
+
+log = logging.getLogger(__name__)
+
+from ._device_daemon import daemon
+from ._meta_parser import prepare_for_sending, to_device_no_copy
+
+
+_IMPL_REGISTRY = {}
+
+
+# Define all the implementations in the registry
+def _register_same_name(name, with_log=False):
+    def _(*args, **kwargs):
+        if with_log:
+            log.info("Calling hook %s", name)
+        return daemon.exec(name, *args, **kwargs)
+
+    _IMPL_REGISTRY[name] = _
+
+
+_register_same_name("deviceCount")
+_register_same_name("getDevice")
+_register_same_name("uncheckedSetDevice")
+_register_same_name("exchangeDevice")
+_register_same_name("malloc", True)
+_register_same_name("free", True)
+
+_openreg_lib = torch.library.Library("_", "IMPL")
+
+
+def _openreg_kernel_fallback(op, *args, **kwargs):
+    log.info("Calling kernel %s", op)
+
+    # Special ops needed to avoid infinite recursion
+    if op is torch.ops.aten._copy_from.default:
+        from_, to_ = args
+        if from_.device.type == to_.device.type:
+            assert from_.device.type == "openreg"
+            op = torch.ops.aten.copy_.default
+            # handled below as a regular copy
+        elif from_.device.type == "openreg":
+            args, _ = prepare_for_sending((from_,), {})
+            host_mem = daemon.exec("send_data", *args)
+            return to_.copy_(host_mem)
+        elif to_.device.type == "openreg":
+            args, _ = prepare_for_sending((to_,), {})
+            daemon.exec("recv_data", from_, *args)
+            return to_
+        else:
+            raise RuntimeError("Should not happen")
+    elif op is torch.ops.aten.set_.source_Tensor:
+        return torch.ops.aten.set_.source_Storage_storage_offset(
+            args[0],
+            args[1].untyped_storage(),
+            args[1].storage_offset(),
+            args[1].size(),
+            args[1].stride(),
+        )
+    elif op is torch.ops.aten._local_scalar_dense.default:
+        args, _ = prepare_for_sending(args, {})
+        host_mem = daemon.exec("send_data", *args)
+        return host_mem.item()
+
+    op_name = None
+    post_process = None
+    if "out" in op._overloadname:
+        # Note that all structured native op will call here
+        if isinstance(kwargs["out"], tuple):
+            raise RuntimeError(f"out= variant {op} with tuple out= not supported")
+        if kwargs["out"].nelement() == 0:
+            # Out variant that needs a resize, convert to an out of place
+            # and handle generically below
+            orig_out = kwargs["out"]
+            del kwargs["out"]
+            if op._overloadname != "out":
+                raise RuntimeError(
+                    "Cannot retranslate non-default out= variant form 0 size"
+                )
+            op = op.overloadpacket.default
+
+            def _post_process():
+                nonlocal real_res
+                orig_out.set_(real_res)
+                real_res = orig_out
+
+            post_process = _post_process
+
+        else:
+            # No metadata update to do, just run the op on the device
+            op_name = op.overloadpacket._qualified_op_name
+            real_res = kwargs["out"]
+    elif not tree_any(lambda obj: isinstance(obj, torch.Tensor), (args, kwargs)):
+        # No Tensor argument means factory function
+        # They should decompose and be handled in our c++ side directly
+        raise RuntimeError(f"{op} not handled yet.")
+    elif op._schema.is_mutable or op is torch.ops.aten._copy_from.default:
+        # Only handle inplace ops returning their first arg
+        assert len(args) >= 1, f"Inplace {op} needs at least one arg"
+        assert (
+            len(op._schema.returns) == 1
+        ), f"NYI Inplace {op} with more than one return"
+        op_name = op.overloadpacket._qualified_op_name
+        real_res = args[0]
+    elif any(r.alias_info is not None for r in op._schema.returns):
+        # View ops
+        if op is torch.ops.aten.view.default:
+            return torch.ops.aten._unsafe_view(*args, **kwargs)
+        raise RuntimeError(f"{op} view op is not handled yet")
+
+    if op_name is None:
+        # 1. Compute updated metadata
+        if torch.Tag.dynamic_output_shape not in op.tags:
+            # Usual case: run the meta op to see the output metadata
+            meta_args, meta_kwargs = to_device_no_copy("meta", args, kwargs)
+            meta_res = op(*meta_args, **meta_kwargs)
+
+            # 2. Allocate the output
+            real_res, _ = to_device_no_copy("openreg", meta_res, {})
+        else:
+            # Slow version for data-dependent functions:
+            # Run the op on the device just to get the output shape
+            args_, kwargs_ = prepare_for_sending(args, kwargs)
+            shape = daemon.exec(
+                "get_op_output_shape",
+                op.overloadpacket._qualified_op_name,
+                args_,
+                kwargs_,
+            )
+
+            # 2. Allocate the output
+            real_res = args[0].new(shape)
+
+        # 3. Move to out variant
+        kwargs["out"] = real_res
+        # Let overload resolution find the out= overload
+        op_name = op.overloadpacket._qualified_op_name
+
+    # 4. Run the compute and populate the output on the device
+    args, kwargs = prepare_for_sending(args, kwargs)
+    daemon.exec("run_op", op_name, args, kwargs)
+
+    if post_process is not None:
+        post_process()
+
+    return real_res
+
+
+_openreg_lib.fallback(_openreg_kernel_fallback, dispatch_key="PrivateUse1")
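To make the control flow of `_openreg_kernel_fallback` concrete, here is a hypothetical walk-through of what a small program would trigger (assuming the extension is installed; the exact op coverage is whatever the C++ extension and the fallback above handle):

```python
import torch

import pytorch_openreg  # noqa: F401

x = torch.randn(4)       # regular CPU tensor
x_dev = x.to("openreg")  # aten._copy_from host->device: daemon.exec("recv_data", ...)
y = x_dev + x_dev        # fallback path: meta run of aten.add -> allocate the output
                         # on "openreg" -> daemon.exec("run_op", "aten::add", ...)
y_cpu = y.cpu()          # aten._copy_from device->host: daemon.exec("send_data", ...)
print(y_cpu)
```

`Tensor.item()` on a device tensor is handled by the `_local_scalar_dense` branch above, which likewise round-trips through `send_data`.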
