[Experimental] Insert prims.shape to ensure NumberProxy is materialized for all dynamic shapes (#2746)
base: main
Conversation
Co-authored-by: crcrpar <[email protected]>
Force-pushed from 36dee09 to fee3d7b.
```python
shape_attr = None
if history is not None:
    shape_attr = ProvenanceRecord(
        PseudoInst.LOAD_ATTR, inputs=[copy.copy(history), wrap_const("shape").provenance]
    )


def _dim_history(idx: int) -> ProvenanceRecord | None:
    if shape_attr is None:
        return None
    return ProvenanceRecord(PseudoInst.BINARY_SUBSCR, inputs=[shape_attr, wrap_const(idx).provenance])
```
As suggested in another comment, I think this should be pulled out into a separate PR. When that happens...
This is worrisome. The error we are encountering is that history is None, right? I don't like the idea that we are creating a TensorProxy for a tensor that doesn't have a history... It also doesn't feel right that we are propagating that lack of history forward. In what case does the input tensor not have a history?
This is the traceback where a TensorProxy is created without history:
```
...
  File "/opt/pytorch/lightning-thunder/thunder/torch/experimental/dtensor_torch_and_prims.py", line 403, in dtensor_from_local_meta
    res = proxify_dtensor(res)
          ^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/torch/experimental/dtensor_utils.py", line 21, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/torch/experimental/dtensor_proxy.py", line 165, in proxify_dtensor
    local_tensor_proxy = proxy(t, history=history)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/core/proxies.py", line 2104, in proxy
    return tensorproxy(x, name=name, history=history)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/core/proxies.py", line 2025, in tensorproxy
    raise Exception(msg)
```
So this comes from a DTensor meta function, which has nothing to do with provenance tracking. In such cases, propagating the lack of history should be fine.
Correct me if I'm wrong, but I thought that meta functions were frequently invoked during the tracing and provenance tracking process... Each time a primitive symbol is executed during tracing, it invokes its meta. Or did you mean something else?
In this prim, we've got the argument x. Isn't it a TensorProxy, probably with some history?
My point was that meta functions in general do not return a proxy with history.
lightning-thunder/thunder/core/prims.py, lines 3817 to 3818 at 4b6272c:

```python
def shallow_copy_meta(a: TensorProxy, /) -> TensorProxy:
    return TensorProxy(like=a)
```
Here, Symbols are wrapped by interpreter_needs_wrap, which wraps the output of meta functions in a WrappedValue and attaches ProvenanceRecord here. So, in my understanding, meta functions under Symbol.__call__ and proxies they deal with do not need to know provenance at all.
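As a rough, hypothetical sketch of that division of labor (the names `WrappedValue` and the wrapping role come from the discussion above; everything else below is invented for illustration): the interpreter-side wrapper attaches provenance to the meta's raw output, so the meta itself never needs to see or produce it:

```python
from dataclasses import dataclass


@dataclass
class WrappedValue:
    # Toy stand-in: Thunder's WrappedValue tracks more than this
    value: object
    provenance: object


def needs_wrap(meta_fn):
    """Invented sketch of interpreter-side wrapping.

    The meta function never produces provenance; the wrapper supplies it
    when boxing the raw result into a WrappedValue.
    """
    def wrapped(*args, provenance=None, **kwargs):
        out = meta_fn(*args, **kwargs)  # meta returns a history-less result
        return WrappedValue(out, provenance)
    return wrapped


@needs_wrap
def shallow_copy_meta(a: dict) -> dict:
    # Stand-in for `return TensorProxy(like=a)`
    return dict(a)
```

Calling `shallow_copy_meta({"shape": (2, 3)}, provenance="call@10")` returns a `WrappedValue` whose `.provenance` was supplied from outside, never by the meta, which is the point being made above.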
I was unsure, so I inserted `print(f"{x.history = }")`.
Patch:

```diff
diff --git a/thunder/torch/experimental/dtensor_torch_and_prims.py b/thunder/torch/experimental/dtensor_torch_and_prims.py
index dfb090c4..0b2e1285 100644
--- a/thunder/torch/experimental/dtensor_torch_and_prims.py
+++ b/thunder/torch/experimental/dtensor_torch_and_prims.py
@@ -388,20 +388,21 @@ if torch.distributed.is_available():
     def dtensor_from_local_meta(
         x,
         mesh,
         placements,
         *,
         run_check: bool = False,
         shape: torch.Size | None = None,
         stride: tuple[int, ...] | None = None,
     ):
+        print(f"{x.history = }")
         res = run_with_fake_tensor(
             DTensor.from_local, x, mesh, placements, run_check=run_check, shape=shape, stride=stride
         )
         from thunder.torch.experimental.dtensor_proxy import proxify_dtensor
```

It only printed `x.history = None`.
I tried `thunder.jit(lambda x: x.exp())` and realized that the proxy input of `exp`'s meta function had history. I am not sure why it's different.
The existence of history depends on where in the interpretation/compilation a tensor proxy is created. If a proxy is created as the function is being interpreted (in interpreter.py), we should expect to see histories populated with ProvenanceRecords. These records are connected with the creation of the prologue trace. If a tensorproxy is created in one of the many optimization passes of the compilation stage, there is no need to keep up with creating ProvenanceRecords, and so histories are often empty. From the backtrace you pasted, I can't tell if this error is happening in the interpretation stage where there should be a history, or in the compilation stage where we don't care.
Thank you for the information! I don't remember the entire stack trace, but what I do remember is that dtensor_from_local_meta was never called with a TensorProxy that had history. I think it is related to how local tensor proxies are created from dtensors.
In my understanding, when the history of a tensor proxy is not recorded, its shape can be history-less as well.
```python
) -> Any:
    nv_fill_value = getnv(fill_value, fd, lc_to_nv_map)
    nvdtype = lcdtype_to_nvdtype(dtype)
    nv_shape = getnv(shape, fd, lc_to_nv_map)
```
Again, the changes in this file should be a separate PR. When that happens...
Should this be `nv_shape = getnv(shape, fd, lc_to_nv_map, inline_number=True)` as Kshiteej suggests in #2677 (comment)?
Yes. I thought I should first figure out what the `inline_number` option does. Since we're dealing with metadata here, I have decided to put `inline_number=True` back.
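One plausible reading of an inline-number option, sketched as a toy (this `getnv_sketch` is invented for illustration and is not the real `getnv` from the nvFuser executor): when the flag is set, a concrete Python number is baked in as a constant instead of being resolved through the traced-value map, which suits shape metadata:

```python
def getnv_sketch(value, lc_to_nv_map: dict, *, inline_number: bool = False):
    """Invented illustration of resolving a value to its nvFuser-side definition.

    With inline_number=True, a plain Python number is returned as-is
    (treated as a constant) rather than looked up as a traced value.
    """
    if inline_number and isinstance(value, (int, float)):
        return value  # bake the number in as a constant
    return lc_to_nv_map[value]  # otherwise resolve the traced definition
```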
Builds upon #2745. Closes #2677.
The following command runs successfully:

```shell
torchrun --nproc-per-node 2 thunder/benchmarks/benchmark_inference.py \
  --num-iterations 10 --mode thunder --thunder-cache "symbolic values"
```