[Experimental] Insert prims.shape to ensure NumberProxy is materialized for all dynamic shapes (#2746)
base: main
Changes from 7 commits: edaf6ba, ade597e, d05dac5, bd7a6ed, 9ed4bcd, f6a87a0, fee3d7b, db012b1
```diff
@@ -1045,10 +1045,11 @@ def full(
 ) -> Any:
     nv_fill_value = getnv(fill_value, fd, lc_to_nv_map)
     nvdtype = lcdtype_to_nvdtype(dtype)
+    nv_shape = getnv(shape, fd, lc_to_nv_map)

     _select_device(fd, device)

-    return fd.ops.full(shape, nv_fill_value, nvdtype)
+    return fd.ops.full(nv_shape, nv_fill_value, nvdtype)


 register_supported(PrimIDs.FULL, full, _full_check)
```

**Collaborator** (on the `nv_shape = getnv(shape, fd, lc_to_nv_map)` line): Again, the changes in this file should be a separate PR. When that happens... Should this be

**Author:** Yes, I thought I should first figure out what the
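For context, the change routes `shape` through `getnv` before calling `fd.ops.full`, so that `NumberProxy` entries (dynamic dimensions) are resolved to their nvFuser counterparts instead of being passed through as raw proxy objects. Below is a minimal, hypothetical sketch of what a `getnv`-style lookup does; it is not Thunder's actual implementation:

```python
# Hypothetical, simplified sketch of a getnv-style lookup; the real helper in
# Thunder's nvFuser executor differs in detail. It resolves proxies through
# lc_to_nv_map and recurses into shape tuples so that dynamic dimensions
# (NumberProxy entries) become nvFuser scalars rather than raw Python objects.
def getnv_sketch(x, fd, lc_to_nv_map):
    if isinstance(x, (tuple, list)):
        # A shape: resolve each dimension, which may be a plain int or a proxy.
        return type(x)(getnv_sketch(dim, fd, lc_to_nv_map) for dim in x)
    if x in lc_to_nv_map:
        # A proxy with a registered nvFuser value (e.g. a dynamic dimension).
        return lc_to_nv_map[x]
    return x  # Plain Python numbers pass through unchanged.
```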
**Collaborator:** As suggested in another comment, I think this should be pulled out into a separate PR. When that happens...

This is worrisome. The error we are encountering is that `history` is `None`, right? I don't like the idea that we are creating a TensorProxy for a tensor that doesn't have a history... It also doesn't feel right that we are propagating that lack of history forward. In what case does the input tensor not have a history?
**Author:** This is the traceback where a `TensorProxy` is created without `history`:

So this comes from a DTensor meta function, which has nothing to do with provenance tracking. So in such cases, propagating the lack of history should be fine.
**Collaborator:** Correct me if I'm wrong, but I thought that meta functions were frequently invoked during the tracing and provenance-tracking process... Each time a primitive symbol is executed during tracing, it invokes its meta. Or did you mean something else?

In this prim, we've got the argument `x`. Isn't it a `TensorProxy`, probably with some history?
**Author:** My point was that meta functions in general do not return a proxy with history.

(See lightning-thunder/thunder/core/prims.py, lines 3817 to 3818 at 4b6272c.)

Here, `Symbol`s are wrapped by `interpreter_needs_wrap`, which wraps the output of meta functions in a `WrappedValue` and attaches a `ProvenanceRecord` here. So, in my understanding, meta functions under `Symbol.__call__` and the proxies they deal with do not need to know about provenance at all.

I was unsure, so I inserted `print(f"{x.history = }")`. It only printed `x.history = None`.
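To make the wrapping flow described above concrete, here is an illustrative sketch; the names are simplified stand-ins, not Thunder's actual definitions:

```python
# Illustrative sketch only: stand-ins for Thunder's WrappedValue /
# ProvenanceRecord machinery, showing why a meta function's output proxy does
# not itself need to carry history; the interpreter attaches provenance.
from dataclasses import dataclass, field


@dataclass
class ProvenanceRecordSketch:
    inst: str                          # which symbol produced the value
    inputs: list = field(default_factory=list)


@dataclass
class WrappedValueSketch:
    value: object                      # the bare proxy returned by the meta function
    provenance: ProvenanceRecordSketch | None = None


def interpreter_needs_wrap_sketch(meta_fn, symbol_name):
    # The interpreter, not the meta function, records where each value came
    # from: it calls the meta, then wraps the bare result with provenance.
    def wrapped(*args, **kwargs):
        out = meta_fn(*args, **kwargs)
        record = ProvenanceRecordSketch(inst=symbol_name, inputs=list(args))
        return WrappedValueSketch(out, record)

    return wrapped
```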
**Author:** I tried `thunder.jit(lambda x: x.exp())` and realized that the proxy input of `exp`'s meta function had history. I am not sure why it's different.
**Collaborator:** The existence of history depends on where in the interpretation/compilation pipeline a tensor proxy is created. If a proxy is created as the function is being interpreted (in interpreter.py), we should expect to see histories populated with ProvenanceRecords; these records are connected with the creation of the prologue trace. If a tensor proxy is created in one of the many optimization passes of the compilation stage, there is no need to keep creating ProvenanceRecords, and so histories are often empty. From the backtrace you pasted, I can't tell whether this error is happening in the interpretation stage, where there should be a history, or in the compilation stage, where we don't care.
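As a rough illustration of that distinction (hypothetical names, not Thunder's API):

```python
# Hypothetical illustration of the two situations described above;
# ProxySketch is a stand-in, not a Thunder class.
class ProxySketch:
    def __init__(self, name, history=None):
        self.name = name
        self.history = history  # ProvenanceRecord-like entries, or None


# Created while the user function is interpreted (interpreter.py): history is
# populated and later feeds the construction of the prologue trace.
interp_proxy = ProxySketch("t0", history=["unpack x", "call exp"])

# Created inside a later optimization pass: no provenance is recorded, so
# history is None by design, not by accident.
pass_proxy = ProxySketch("t1")
```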
**Author:** Thank you for your information! I don't remember the entire stack trace, but what I remember is that `dtensor_from_local_meta` was never called with a `TensorProxy` with history. I think it is related to how local tensor proxies are created from DTensors.

In my understanding, when the history of a tensor proxy is not recorded, its shape can also be history-less anyway.
dtensor_from_local_metawas never called with a TensorProxy with history. I think it is related to how local tensor proxies are created from dtensors.In my understanding, when the history of a tensor proxy is not recorded, its shape can also be history-less anyway.