diff --git a/thunder/__init__.py b/thunder/__init__.py index c0483f8c3a..16bfd78cbb 100644 --- a/thunder/__init__.py +++ b/thunder/__init__.py @@ -51,6 +51,7 @@ from thunder.core.transform_common import ( Transform, dce, + ensure_symbolic_shape_bindings, remove_context_manager_prims_from_trace, unwrap_return_value, wrap_return_value_together_with_arguments, @@ -544,6 +545,10 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com computation_trc = dce(computation_trc) computation_traces.append(computation_trc) + if thunder.core.compile_data.using_symbolic_values(): + computation_trc = ensure_symbolic_shape_bindings(computation_trc) + computation_traces.append(computation_trc) + if not cd.disable_torch_autograd_support: tensor_cls = (pytorch.Tensor, TensorProxy) requires_grad = any(isinstance(arg, tensor_cls) and arg.requires_grad for arg in computation_trc.args) diff --git a/thunder/core/interpreter.py b/thunder/core/interpreter.py index 228e1a4b9f..235c6689ea 100644 --- a/thunder/core/interpreter.py +++ b/thunder/core/interpreter.py @@ -7041,7 +7041,7 @@ def _impl(fn, *args, **kwargs): if ( unbound_fn_candidate is not None and isinstance(unbound_fn_candidate, (WrapperDescriptorType, MethodDescriptorType)) - and unbound_fn_candidate.__get__(slf) == fn + and unbound_fn_candidate.__get__(slf, type(slf)) == fn ): unbound_fn = unbound_fn_candidate break diff --git a/thunder/core/proxies.py b/thunder/core/proxies.py index 26efb361a7..5abd412adc 100644 --- a/thunder/core/proxies.py +++ b/thunder/core/proxies.py @@ -2009,14 +2009,23 @@ def tensorproxy(t: torch.Tensor, /, *, name: None | str, history: None | tuple = # See Note [DistributedDataParallel and distparallel_type] distparallel_type = getattr(t, "distparallel_type", None) _thunder_fsdp_padding_size = getattr(t, "_thunder_fsdp_padding_size", None) - # For parameters, shapes should be static. if using_symbolic_values() and not isinstance(t, torch.nn.Parameter): - shape_attr = ProvenanceRecord(PseudoInst.LOAD_ATTR, inputs=[copy.copy(history), wrap_const("shape").provenance]) + shape_attr = None + if history is not None: + shape_attr = ProvenanceRecord( + PseudoInst.LOAD_ATTR, inputs=[copy.copy(history), wrap_const("shape").provenance] + ) + + def _dim_history(idx: int) -> ProvenanceRecord | None: + if shape_attr is None: + return None + return ProvenanceRecord(PseudoInst.BINARY_SUBSCR, inputs=[shape_attr, wrap_const(idx).provenance]) + shape = tuple( IntegerProxy( None, s, - history=ProvenanceRecord(PseudoInst.BINARY_SUBSCR, inputs=[shape_attr, wrap_const(idx).provenance]), + history=_dim_history(idx), constraint=CONSTRAINT.CONSTRAINABLE, ) for idx, s in enumerate(t.shape) diff --git a/thunder/core/transform_common.py b/thunder/core/transform_common.py index bc14a671f2..9521565cbf 100644 --- a/thunder/core/transform_common.py +++ b/thunder/core/transform_common.py @@ -1,4 +1,5 @@ from __future__ import annotations +from collections import defaultdict import time from typing import TYPE_CHECKING from abc import ABC @@ -10,7 +11,7 @@ import thunder import thunder.core.prims as prims from thunder.core.baseutils import BoundSymbolInterface, NumberProxyInterface -from thunder.core.proxies import Proxy, variableify, Variable +from thunder.core.proxies import Proxy, TensorProxy, variableify, Variable from thunder.core.pytree import tree_flatten, tree_iter, tree_map from thunder.core.symbol import BoundSymbol, BoundSymbolRHS, has_tags from thunder.core.trace import from_trace, TraceProvenance, TraceCtx as Trace @@ -516,3 +517,42 @@ def is_context_manager_prim(bsym): new_trace.bound_symbols = filtered_bsyms new_trace.set_provenance(TraceProvenance("Remove context manager prims")) return new_trace + + +def ensure_symbolic_shape_bindings(trace: Trace) -> Trace: + """Insert prims.shape bound symbols for dynamically shaped tensors whose shape dims lack producers.""" + tensor_map: dict[str, TensorProxy] = {} + + def record_if_tensor(obj) -> None: + if isinstance(obj, TensorProxy): + tensor_map[obj.name] = obj + + tree_map(record_if_tensor, (trace.args, trace.kwargs)) + + for bsym in trace.bound_symbols: + tree_map(record_if_tensor, bsym.output) + + producer_map = producers(trace) + pending: dict[BoundSymbol, list[TensorProxy]] = defaultdict(list) + new_bsyms = [] + + for tensor in tensor_map.values(): + needs_materialization = any( + isinstance(dim, Proxy) and producer_map.get(dim, None) is None for dim in tensor._shape + ) + if needs_materialization: + tensor_producer = producer_map.get(tensor, None) + if tensor_producer is not None: + pending[tensor_producer].append(tensor) + else: + new_bsyms.append(prims.shape.bind(tensor, output=tensor._shape)) + + for bsym in trace.bound_symbols: + new_bsyms.append(bsym) + for tensor in pending.get(bsym, []): + new_bsyms.append(prims.shape.bind(tensor, output=tensor._shape)) + + new_trace = from_trace(trace) + new_trace.bound_symbols = new_bsyms + new_trace.set_provenance(TraceProvenance("Inserted prims.shape for symbolic dims")) + return new_trace diff --git a/thunder/core/update_aliases.py b/thunder/core/update_aliases.py index ffbb59e632..ead770cb9d 100644 --- a/thunder/core/update_aliases.py +++ b/thunder/core/update_aliases.py @@ -55,6 +55,11 @@ def _involves_viewed_args(bsym, viewed): return any(isinstance(p, TensorProxy) and variableify(p) in viewed for p in bsym.flat_proxy_args) +def _static_numel(tensor: TensorProxy) -> int | None: + numel = getattr(tensor, "_numel", None) + return numel if isinstance(numel, int) else None + + def replace_args_with_alias_map( computation_trace: Trace, alias_tensor_indices: list[list[int]], @@ -69,9 +74,11 @@ def replace_args_with_alias_map( arg = flat_args[indices[0]] for idx in filter(lambda idx: idx < len(flat_args), indices[1:]): arg_to_replace = flat_args[idx] - # Skip aliases with different numel (e.g., complex tensor and its real view) + # Skip aliases with different statically-known numel (e.g., complex tensor and its real view) # These share storage but have incompatible element counts - if arg.numel != arg_to_replace.numel: + arg_numel = _static_numel(arg) + arg_to_replace_numel = _static_numel(arg_to_replace) + if arg_numel is not None and arg_to_replace_numel is not None and arg_numel != arg_to_replace_numel: continue reshaped_arg = arg if arg_to_replace.shape != arg.shape: diff --git a/thunder/core/utils.py b/thunder/core/utils.py index 40b3e0eb13..aa592d183c 100644 --- a/thunder/core/utils.py +++ b/thunder/core/utils.py @@ -1033,20 +1033,8 @@ def producers(trace_or_bsyms: TraceCtx | list[BoundSymbolInterface], *, _map_to_ vout = variableify(out) - # Checks if the proxy was also an input (in which case this is not its producers) - is_input: bool = False - for vin in bsym.flat_variableified_proxy_args: - if vin == vout: - is_input = True - break - - if is_input: - continue - - if _map_to_numbers: - producers[out] = idx - else: - producers[out] = bsym + if vout not in bsym.flat_variableified_proxy_args: + producers[out] = idx if _map_to_numbers else bsym return producers diff --git a/thunder/executors/data_dependent_partition.py b/thunder/executors/data_dependent_partition.py index aca064125c..c9f009f2ff 100644 --- a/thunder/executors/data_dependent_partition.py +++ b/thunder/executors/data_dependent_partition.py @@ -95,6 +95,11 @@ def __init__(self, trace: TraceCtx): producers = utils.producers(trace, _map_to_numbers=True) consumers = utils.consumers(trace, _map_to_numbers=True) + trace_input_vars = set() + flat_inputs, _ = utils.tree_flatten((trace.args, trace.kwargs)) + for obj in flat_inputs: + if isinstance(obj, Proxy): + trace_input_vars.add(variableify(obj)) # Note, even though BoundSymbolInterface is hashable, it's hash is very slow # as it appears to be far off from being universal. @@ -117,6 +122,12 @@ def __init__(self, trace: TraceCtx): if not isinstance(inp, Proxy): continue + producer_id = producers.get(inp, None) + if producer_id is None and variableify(inp) in trace_input_vars: + # TODO: This should not happen + # I observed (i23,) = prims.shape(l_cache_position_) hits this branch + # check(False, "Unpacked trace input consumed") + continue producer_id = producers[inp] parent = bsym_id_to_node_map[producer_id] node.parents.add(parent) diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py index b65d6944e0..7c39b78a7d 100644 --- a/thunder/executors/nvfuserex_impl.py +++ b/thunder/executors/nvfuserex_impl.py @@ -1045,10 +1045,11 @@ def full( ) -> Any: nv_fill_value = getnv(fill_value, fd, lc_to_nv_map) nvdtype = lcdtype_to_nvdtype(dtype) + nv_shape = getnv(shape, fd, lc_to_nv_map) _select_device(fd, device) - return fd.ops.full(shape, nv_fill_value, nvdtype) + return fd.ops.full(nv_shape, nv_fill_value, nvdtype) register_supported(PrimIDs.FULL, full, _full_check) diff --git a/thunder/executors/utils.py b/thunder/executors/utils.py index 6c12685329..12cd154006 100644 --- a/thunder/executors/utils.py +++ b/thunder/executors/utils.py @@ -50,9 +50,13 @@ def __init__(self, producers, consumers, bound_symbols: list[BoundSymbol]): for bsym in self.bound_symbols: flatouts = bsym.flat_outs - produces.update( - variableify(x) for x in flatouts if isinstance(x, Proxy) and producers[x] in self.bound_symbols - ) + for x in flatouts: + if not isinstance(x, Proxy): + continue + producer_bsym = producers.get(x, None) + # TODO: x should have a producer + if producer_bsym in self.bound_symbols: + produces.add(variableify(x)) # Short-circuits if the symbol is a comment, because comments don't consume anything # Note that comments may produce things @@ -69,7 +73,9 @@ def __init__(self, producers, consumers, bound_symbols: list[BoundSymbol]): for x in consumes: x = unvariableify(x) - if producers[x] not in self.bound_symbols: + # TODO: x should have a producer + producer_bsym = producers.get(x, None) + if producer_bsym not in self.bound_symbols: inputs.add(variableify(x)) # Outputs are things this produces that are consumed after it