|
| 1 | +import logging |
| 2 | + |
| 3 | +import torch |
| 4 | +from torch.utils._pytree import tree_any |
| 5 | + |
| 6 | + |
| 7 | +log = logging.getLogger(__name__) |
| 8 | + |
| 9 | +from ._device_daemon import daemon |
| 10 | +from ._meta_parser import prepare_for_sending, to_device_no_copy |
| 11 | + |
| 12 | + |
| 13 | +_IMPL_REGISTRY = {} |
| 14 | + |
| 15 | + |
# Define all the implementations in the registry
def _register_same_name(name, with_log=False):
    """Install a hook in ``_IMPL_REGISTRY`` under *name*.

    The hook forwards all of its arguments to ``daemon.exec`` using the
    same *name*, optionally logging each invocation first.
    """

    def _forward(*call_args, **call_kwargs):
        # Logging is opt-in so chatty hooks stay quiet by default.
        if with_log:
            log.info("Calling hook %s", name)
        return daemon.exec(name, *call_args, **call_kwargs)

    _IMPL_REGISTRY[name] = _forward
| 24 | + |
| 25 | + |
# Device-management hooks forwarded verbatim to the daemon. malloc/free
# are registered with logging enabled — allocator traffic is the most
# useful trace when debugging the device backend.
_register_same_name("deviceCount")
_register_same_name("getDevice")
_register_same_name("uncheckedSetDevice")
_register_same_name("exchangeDevice")
_register_same_name("malloc", True)
_register_same_name("free", True)

# Library handle used to attach implementations to the PrivateUse1
# dispatch key (the fallback kernel is installed at the bottom of this file).
_openreg_lib = torch.library.Library("_", "IMPL")
| 34 | + |
| 35 | + |
def _openreg_kernel_fallback(op, *args, **kwargs):
    """Generic fallback kernel for every aten op dispatched to "openreg".

    Strategy: special-case the few ops needed to avoid recursion
    (copies, set_, scalar extraction), then normalize every remaining op
    into an out= form whose output Tensor is pre-allocated here, and
    finally ship the call to the device daemon via ``daemon.exec``.

    Note: ``op``, ``args`` and ``kwargs`` are deliberately rebound along
    the way — the order of these mutations is load-bearing.
    """
    log.info("Calling kernel %s", op)

    # Special ops needed to avoid infinite recursion
    if op is torch.ops.aten._copy_from.default:
        from_, to_ = args
        if from_.device.type == to_.device.type:
            # Device-to-device copy on openreg itself: rewrite as copy_
            # and fall through to the generic in-place handling below.
            assert from_.device.type == "openreg"
            op = torch.ops.aten.copy_.default
            # handled below as a regular copy
        elif from_.device.type == "openreg":
            # openreg -> other device: pull the data to host memory, then
            # let the destination Tensor's own copy_ finish the job.
            args, _ = prepare_for_sending((from_,), {})
            host_mem = daemon.exec("send_data", *args)
            return to_.copy_(host_mem)
        elif to_.device.type == "openreg":
            # other device -> openreg: push the source into the
            # daemon-held buffer backing `to_`.
            args, _ = prepare_for_sending((to_,), {})
            daemon.exec("recv_data", from_, *args)
            return to_
        else:
            raise RuntimeError("Should not happen")
    elif op is torch.ops.aten.set_.source_Tensor:
        # Re-express set_(Tensor) via the storage-based overload so the
        # metadata (offset/size/stride) transfers explicitly.
        return torch.ops.aten.set_.source_Storage_storage_offset(
            args[0],
            args[1].untyped_storage(),
            args[1].storage_offset(),
            args[1].size(),
            args[1].stride(),
        )
    elif op is torch.ops.aten._local_scalar_dense.default:
        # .item(): fetch the value to host and unwrap it there.
        args, _ = prepare_for_sending(args, {})
        host_mem = daemon.exec("send_data", *args)
        return host_mem.item()

    # op_name != None means "ready to run on the device as-is";
    # post_process, when set, patches metadata after the device call.
    op_name = None
    post_process = None
    if "out" in op._overloadname:
        # Note that all structured native op will call here
        if isinstance(kwargs["out"], tuple):
            raise RuntimeError(f"out= variant {op} with tuple out= not supported")
        if kwargs["out"].nelement() == 0:
            # Out variant that needs a resize, convert to an out of place
            # and handle generically below
            orig_out = kwargs["out"]
            del kwargs["out"]
            if op._overloadname != "out":
                raise RuntimeError(
                    "Cannot retranslate non-default out= variant form 0 size"
                )
            op = op.overloadpacket.default

            def _post_process():
                # After the device computed into the freshly allocated
                # result, alias the caller's original out= Tensor to it.
                nonlocal real_res
                orig_out.set_(real_res)
                real_res = orig_out

            post_process = _post_process

        else:
            # No metadata update to do, just run the op on the device
            op_name = op.overloadpacket._qualified_op_name
            real_res = kwargs["out"]
    elif not tree_any(lambda obj: isinstance(obj, torch.Tensor), (args, kwargs)):
        # No Tensor argument means factory function
        # They should decompose and be handled in our c++ side directly
        raise RuntimeError(f"{op} not handled yet.")
    elif op._schema.is_mutable or op is torch.ops.aten._copy_from.default:
        # Only handle inplace ops returning their first arg
        assert len(args) >= 1, f"Inplace {op} needs at least one arg"
        assert (
            len(op._schema.returns) == 1
        ), f"NYI Inplace {op} with more than one return"
        op_name = op.overloadpacket._qualified_op_name
        real_res = args[0]
    elif any(r.alias_info is not None for r in op._schema.returns):
        # View ops
        if op is torch.ops.aten.view.default:
            # _unsafe_view has the same result without the aliasing
            # bookkeeping this fallback cannot provide.
            return torch.ops.aten._unsafe_view(*args, **kwargs)
        raise RuntimeError(f"{op} view op is not handled yet")

    if op_name is None:
        # 1. Compute updated metadata
        if torch.Tag.dynamic_output_shape not in op.tags:
            # Usual case: run the meta op to see the output metadata
            meta_args, meta_kwargs = to_device_no_copy("meta", args, kwargs)
            meta_res = op(*meta_args, **meta_kwargs)

            # 2. Allocate the output
            real_res, _ = to_device_no_copy("openreg", meta_res, {})
        else:
            # Slow version for data-dependent functions:
            # Run the op on the device just to get the output shape
            args_, kwargs_ = prepare_for_sending(args, kwargs)
            shape = daemon.exec(
                "get_op_output_shape",
                op.overloadpacket._qualified_op_name,
                args_,
                kwargs_,
            )

            # 2. Allocate the output
            # NOTE(review): inherits dtype/device from args[0] — assumes
            # the first arg is a Tensor with the right dtype; confirm.
            real_res = args[0].new(shape)

        # 3. Move to out variant
        kwargs["out"] = real_res
        # Let overload resolution find the out= overload
        op_name = op.overloadpacket._qualified_op_name

    # 4. Run the compute and populate the output on the device
    args, kwargs = prepare_for_sending(args, kwargs)
    daemon.exec("run_op", op_name, args, kwargs)

    if post_process is not None:
        post_process()

    return real_res
| 151 | + |
| 152 | + |
| 153 | +_openreg_lib.fallback(_openreg_kernel_fallback, dispatch_key="PrivateUse1") |
0 commit comments