5 changes: 5 additions & 0 deletions thunder/__init__.py
@@ -51,6 +51,7 @@
from thunder.core.transform_common import (
Transform,
dce,
ensure_symbolic_shape_bindings,
remove_context_manager_prims_from_trace,
unwrap_return_value,
wrap_return_value_together_with_arguments,
@@ -544,6 +545,10 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com
computation_trc = dce(computation_trc)
computation_traces.append(computation_trc)

if thunder.core.compile_data.using_symbolic_values():
computation_trc = ensure_symbolic_shape_bindings(computation_trc)
computation_traces.append(computation_trc)

if not cd.disable_torch_autograd_support:
tensor_cls = (pytorch.Tensor, TensorProxy)
requires_grad = any(isinstance(arg, tensor_cls) and arg.requires_grad for arg in computation_trc.args)
2 changes: 1 addition & 1 deletion thunder/core/interpreter.py
@@ -7041,7 +7041,7 @@ def _impl(fn, *args, **kwargs):
if (
unbound_fn_candidate is not None
and isinstance(unbound_fn_candidate, (WrapperDescriptorType, MethodDescriptorType))
and unbound_fn_candidate.__get__(slf) == fn
and unbound_fn_candidate.__get__(slf, type(slf)) == fn
):
unbound_fn = unbound_fn_candidate
break
14 changes: 12 additions & 2 deletions thunder/core/proxies.py
@@ -2010,12 +2010,22 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple =
distparallel_type = getattr(t, "distparallel_type", None)
_thunder_fsdp_padding_size = getattr(t, "_thunder_fsdp_padding_size", None)
if using_symbolic_values():
shape_attr = ProvenanceRecord(PseudoInst.LOAD_ATTR, inputs=[copy.copy(history), wrap_const("shape").provenance])
shape_attr = None
if history is not None:
shape_attr = ProvenanceRecord(
PseudoInst.LOAD_ATTR, inputs=[copy.copy(history), wrap_const("shape").provenance]
)

def _dim_history(idx: int) -> ProvenanceRecord | None:
if shape_attr is None:
return None
return ProvenanceRecord(PseudoInst.BINARY_SUBSCR, inputs=[shape_attr, wrap_const(idx).provenance])

Comment on lines +2013 to +2023
Collaborator

As suggested in another comment, I think this should be pulled out into a separate PR. When that happens...

This is worrisome. The error we are encountering is that history is None, right? I don't like the idea that we are creating a TensorProxy for a tensor that doesn't have a history... It also doesn't feel right that we are propagating that lack of history forward. In what case does the input tensor not have a history?

Collaborator Author
@shino16 — Nov 18, 2025

This is the traceback where a TensorProxy is created without history:

...
  File "/opt/pytorch/lightning-thunder/thunder/torch/experimental/dtensor_torch_and_prims.py", line 403, in dtensor_from_local_meta
    res = proxify_dtensor(res)
          ^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/torch/experimental/dtensor_utils.py", line 21, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/torch/experimental/dtensor_proxy.py", line 165, in proxify_dtensor
    local_tensor_proxy = proxy(t, history=history)
                         ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/core/proxies.py", line 2104, in proxy
    return tensorproxy(x, name=name, history=history)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/core/proxies.py", line 2025, in tensorproxy
    raise Exception(msg)

So this comes from a DTensor meta function, which has nothing to do with provenance tracking. In such cases, propagating the lack of history should be fine.

Collaborator

Correct me if I'm wrong, but I thought that meta functions were frequently invoked during the tracing and provenance tracking process... Each time a primitive symbol is executed during tracing, it invokes its meta. Or did you mean something else?

In this prim, we've got the argument x. Isn't it a TensorProxy, probably with some history?

Collaborator Author
@shino16 — Nov 19, 2025

My point was that meta functions in general do not return a proxy with history.

def shallow_copy_meta(a: TensorProxy, /) -> TensorProxy:
return TensorProxy(like=a)

Here, Symbols are wrapped by interpreter_needs_wrap, which wraps the output of meta functions in a WrappedValue and attaches a ProvenanceRecord. So, in my understanding, meta functions under Symbol.__call__ and the proxies they deal with do not need to know about provenance at all.
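
As a rough illustration of that wrapping pattern (a minimal sketch with hypothetical names, not thunder's actual internals):

from dataclasses import dataclass
from typing import Any, Callable

@dataclass
class WrappedValue:
    value: Any
    provenance: Any | None = None  # attached by the interpreter-side wrapper, never by the meta

def needs_wrap(meta: Callable) -> Callable:
    # stand-in for interpreter_needs_wrap: the wrapper, not the meta, records provenance
    def wrapped(*args, **kwargs):
        out = meta(*args, **kwargs)  # the meta returns a plain proxy with history=None
        return WrappedValue(out, provenance="record built from the call site")
    return wrapped

@needs_wrap
def shallow_copy_meta(a):
    return a  # stands in for TensorProxy(like=a); it knows nothing about provenance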

I was unsure, so I inserted print(f"{x.history = }").

Patch
diff --git a/thunder/torch/experimental/dtensor_torch_and_prims.py b/thunder/torch/experimental/dtensor_torch_and_prims.py
index dfb090c4..0b2e1285 100644
--- a/thunder/torch/experimental/dtensor_torch_and_prims.py
+++ b/thunder/torch/experimental/dtensor_torch_and_prims.py
@@ -388,20 +388,21 @@ if torch.distributed.is_available():
 
     def dtensor_from_local_meta(
         x,
         mesh,
         placements,
         *,
         run_check: bool = False,
         shape: torch.Size | None = None,
         stride: tuple[int, ...] | None = None,
     ):
+        print(f"{x.history = }")
         res = run_with_fake_tensor(
             DTensor.from_local, x, mesh, placements, run_check=run_check, shape=shape, stride=stride
         )
         from thunder.torch.experimental.dtensor_proxy import proxify_dtensor

It only printed x.history = None.

Collaborator Author
@shino16 — Nov 19, 2025

I tried thunder.jit(lambda x: x.exp()) and realized that the proxy input of exp's meta function had history. I am not sure why it's different.

Collaborator

The existence of history depends on where in the interpretation/compilation pipeline a tensor proxy is created. If a proxy is created as the function is being interpreted (in interpreter.py), we should expect to see histories populated with ProvenanceRecords; these records are connected with the creation of the prologue trace. If a tensor proxy is created in one of the many optimization passes of the compilation stage, there is no need to keep creating ProvenanceRecords, so histories are often empty. From the backtrace you pasted, I can't tell whether this error is happening in the interpretation stage, where there should be a history, or in the compilation stage, where we don't care.

Collaborator Author

Thank you for the explanation! I don't remember the entire stack trace, but what I remember is that dtensor_from_local_meta was never called with a TensorProxy that had history. I think it is related to how local tensor proxies are created from DTensors.

In my understanding, when the history of a tensor proxy is not recorded, its shape can be history-less as well.

shape = tuple(
IntegerProxy(
None,
s,
history=ProvenanceRecord(PseudoInst.BINARY_SUBSCR, inputs=[shape_attr, wrap_const(idx).provenance]),
history=_dim_history(idx),
constraint=CONSTRAINT.CONSTRAINABLE,
)
for idx, s in enumerate(t.shape)
42 changes: 41 additions & 1 deletion thunder/core/transform_common.py
@@ -1,4 +1,5 @@
from __future__ import annotations
from collections import defaultdict
import time
from typing import TYPE_CHECKING
from abc import ABC
@@ -10,7 +11,7 @@
import thunder
import thunder.core.prims as prims
from thunder.core.baseutils import BoundSymbolInterface, NumberProxyInterface
from thunder.core.proxies import Proxy, variableify, Variable
from thunder.core.proxies import Proxy, TensorProxy, variableify, Variable
from thunder.core.pytree import tree_flatten, tree_iter, tree_map
from thunder.core.symbol import BoundSymbol, BoundSymbolRHS, has_tags
from thunder.core.trace import from_trace, TraceProvenance, TraceCtx as Trace
@@ -516,3 +517,42 @@ def is_context_manager_prim(bsym):
new_trace.bound_symbols = filtered_bsyms
new_trace.set_provenance(TraceProvenance("Remove context manager prims"))
return new_trace


def ensure_symbolic_shape_bindings(trace: Trace) -> Trace:
"""Insert prims.shape bound symbols for dynamically shaped tensors whose shape dims lack producers."""
tensor_map: dict[str, TensorProxy] = {}

def record_if_tensor(obj) -> None:
if isinstance(obj, TensorProxy):
tensor_map[obj.name] = obj

tree_map(record_if_tensor, (trace.args, trace.kwargs))

for bsym in trace.bound_symbols:
tree_map(record_if_tensor, bsym.output)

producer_map = producers(trace)
pending: dict[BoundSymbol, list[TensorProxy]] = defaultdict(list)
new_bsyms = []

for tensor in tensor_map.values():
needs_materialization = any(
isinstance(dim, Proxy) and producer_map.get(dim, None) is None for dim in tensor._shape
)
if needs_materialization:
tensor_producer = producer_map.get(tensor, None)
if tensor_producer is not None:
pending[tensor_producer].append(tensor)
else:
new_bsyms.append(prims.shape.bind(tensor, output=tensor._shape))

for bsym in trace.bound_symbols:
new_bsyms.append(bsym)
for tensor in pending.get(bsym, []):
new_bsyms.append(prims.shape.bind(tensor, output=tensor._shape))

new_trace = from_trace(trace)
new_trace.bound_symbols = new_bsyms
new_trace.set_provenance(TraceProvenance("Inserted prims.shape for symbolic dims"))
return new_trace
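
For context, a rough before/after sketch of what this pass is meant to produce (an illustrative trace fragment with made-up names, not actual thunder trace output): a symbolic dimension such as i0 that later bound symbols may refer to, but that nothing produces, gets an explicit prims.shape binding.

# Before: t0 has symbolic shape [i0, i1], but no bound symbol produces i0 or i1
def computation(t0):
    # t0: "cuda:0 f32[i0, i1]"
    t1 = ltorch.reshape(t0, (-1,))
    return t1

# After ensure_symbolic_shape_bindings: the symbolic dims now have a producer
def computation(t0):
    # t0: "cuda:0 f32[i0, i1]"
    (i0, i1) = prims.shape(t0)
    t1 = ltorch.reshape(t0, (-1,))
    return t1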
11 changes: 9 additions & 2 deletions thunder/core/update_aliases.py
@@ -55,6 +55,11 @@ def _involves_viewed_args(bsym, viewed):
return any(isinstance(p, TensorProxy) and variableify(p) in viewed for p in bsym.flat_proxy_args)


def _static_numel(tensor: TensorProxy) -> int | None:
numel = getattr(tensor, "_numel", None)
return numel if isinstance(numel, int) else None


def replace_args_with_alias_map(
computation_trace: Trace,
alias_tensor_indices: list[list[int]],
@@ -69,9 +74,11 @@ def replace_args_with_alias_map(
arg = flat_args[indices[0]]
for idx in filter(lambda idx: idx < len(flat_args), indices[1:]):
arg_to_replace = flat_args[idx]
# Skip aliases with different numel (e.g., complex tensor and its real view)
# Skip aliases with different statically-known numel (e.g., complex tensor and its real view)
# These share storage but have incompatible element counts
if arg.numel != arg_to_replace.numel:
arg_numel = _static_numel(arg)
arg_to_replace_numel = _static_numel(arg_to_replace)
if arg_numel is not None and arg_to_replace_numel is not None and arg_numel != arg_to_replace_numel:
continue
reshaped_arg = arg
if arg_to_replace.shape != arg.shape:
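As a concrete illustration of the alias case mentioned in the comment above (plain PyTorch, independent of this PR): a complex tensor and its real view share storage but have different element counts, and under symbolic values numel may be a proxy expression rather than a plain int, which is why the comparison is only applied when both counts are statically known.

import torch

a = torch.randn(4, dtype=torch.complex64)
b = torch.view_as_real(a)             # shares storage with a; shape becomes (4, 2)

print(a.numel(), b.numel())           # 4 8 -- same storage, different element counts
print(a.data_ptr() == b.data_ptr())   # True -- they alias the same memory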
16 changes: 2 additions & 14 deletions thunder/core/utils.py
@@ -1033,20 +1033,8 @@ def producers(trace_or_bsyms: TraceCtx | list[BoundSymbolInterface], *, _map_to_

vout = variableify(out)

# Checks if the proxy was also an input (in which case this is not its producers)
is_input: bool = False
for vin in bsym.flat_variableified_proxy_args:
if vin == vout:
is_input = True
break

if is_input:
continue

if _map_to_numbers:
producers[out] = idx
else:
producers[out] = bsym
if vout not in bsym.flat_variableified_proxy_args:
producers[out] = idx if _map_to_numbers else bsym

return producers

11 changes: 11 additions & 0 deletions thunder/executors/data_dependent_partition.py
@@ -95,6 +95,11 @@ def __init__(self, trace: TraceCtx):

producers = utils.producers(trace, _map_to_numbers=True)
consumers = utils.consumers(trace, _map_to_numbers=True)
trace_input_vars = set()
flat_inputs, _ = utils.tree_flatten((trace.args, trace.kwargs))
for obj in flat_inputs:
if isinstance(obj, Proxy):
trace_input_vars.add(variableify(obj))

# Note, even though BoundSymbolInterface is hashable, it's hash is very slow
# as it appears to be far off from being universal.
@@ -117,6 +122,12 @@
if not isinstance(inp, Proxy):
continue

producer_id = producers.get(inp, None)
if producer_id is None and variableify(inp) in trace_input_vars:
# TODO: This should not happen
# I observed (i23,) = prims.shape(l_cache_position_) hits this branch
# check(False, "Unpacked trace input consumed")
continue
producer_id = producers[inp]
parent = bsym_id_to_node_map[producer_id]
node.parents.add(parent)
3 changes: 2 additions & 1 deletion thunder/executors/nvfuserex_impl.py
@@ -1045,10 +1045,11 @@ def full(
) -> Any:
nv_fill_value = getnv(fill_value, fd, lc_to_nv_map)
nvdtype = lcdtype_to_nvdtype(dtype)
nv_shape = getnv(shape, fd, lc_to_nv_map)
Collaborator

Again, the changes in this file should be a separate PR. When that happens...

Should this be nv_shape = getnv(shape, fd, lc_to_nv_map, inline_number=True) as Kshiteej suggests in #2677 (comment)?

Collaborator Author

Yes. I wanted to first figure out what the inline_number option does; since we're dealing with metadata here, I have put inline_number=True back.


_select_device(fd, device)

return fd.ops.full(shape, nv_fill_value, nvdtype)
return fd.ops.full(nv_shape, nv_fill_value, nvdtype)


register_supported(PrimIDs.FULL, full, _full_check)
14 changes: 10 additions & 4 deletions thunder/executors/utils.py
@@ -50,9 +50,13 @@ def __init__(self, producers, consumers, bound_symbols: list[BoundSymbol]):
for bsym in self.bound_symbols:
flatouts = bsym.flat_outs

produces.update(
variableify(x) for x in flatouts if isinstance(x, Proxy) and producers[x] in self.bound_symbols
)
for x in flatouts:
if not isinstance(x, Proxy):
continue
producer_bsym = producers.get(x, None)
# TODO: x should have a producer
if producer_bsym in self.bound_symbols:
produces.add(variableify(x))

# Short-circuits if the symbol is a comment, because comments don't consume anything
# Note that comments may produce things
Expand All @@ -69,7 +73,9 @@ def __init__(self, producers, consumers, bound_symbols: list[BoundSymbol]):
for x in consumes:
x = unvariableify(x)

if producers[x] not in self.bound_symbols:
# TODO: x should have a producer
producer_bsym = producers.get(x, None)
if producer_bsym not in self.bound_symbols:
inputs.add(variableify(x))

# Outputs are things this produces that are consumed after it