Description
While running python thunder/benchmarks/benchmark_inference.py --input-length 32 --output-length 3 --mode thunder --num-iterations 10, one of the FX graphs we receive is the following.
FXGraph
class GraphModule(torch.nn.Module):
def forward(self, view_as_real_1: "f32[1, 1, 8, 64, 2]", xq_out: "f32[1, 1, 40, 128]", query_states: "bf16[1, 1, 40, 128]", key_states: "bf16[1, 1, 8, 128]", value_states: "bf16[1, 8, 1, 128]", l_past_key_values_cumulative_length_0_: "Sym(s78)", k_out: "bf16[1, 8, 35, 128]", l_cache_position_: "i64[1]", v_out: "bf16[1, 8, 35, 128]", l_stack0_0_: "bf16[1, 1, 1, 8192]", l_self_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_weight_: "bf16[5120, 5120]", l_inputs_embeds_: "bf16[1, 1, 5120]", l_self_modules_layers_modules_0_modules_post_attention_layernorm_parameters_weight_: "bf16[5120]", l_self_modules_layers_modules_0_modules_feed_forward_modules_gate_proj_parameters_weight_: "bf16[16384, 5120]", l_self_modules_layers_modules_0_modules_feed_forward_modules_up_proj_parameters_weight_: "bf16[16384, 5120]", l_self_modules_layers_modules_0_modules_feed_forward_modules_down_proj_parameters_weight_: "bf16[5120, 16384]", l_self_modules_layers_modules_1_modules_input_layernorm_parameters_weight_: "bf16[5120]", l_self_modules_layers_modules_1_modules_self_attn_modules_q_proj_parameters_weight_: "bf16[5120, 5120]", l_self_modules_layers_modules_1_modules_self_attn_modules_k_proj_parameters_weight_: "bf16[1024, 5120]", l_self_modules_layers_modules_1_modules_self_attn_modules_v_proj_parameters_weight_: "bf16[1024, 5120]", freqs_cis_1: "c64[1, 1, 64]"):
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:218 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
xk_out: "f32[1, 1, 8, 128]" = view_as_real_1.flatten(3); view_as_real_1 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:219 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk)
query_states_1: "bf16[1, 1, 40, 128]" = xq_out.type_as(query_states); xq_out = query_states = None
key_states_1: "bf16[1, 1, 8, 128]" = xk_out.type_as(key_states); xk_out = key_states = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:354 in forward, code: query_states = query_states.transpose(1, 2)
query_states_2: "bf16[1, 40, 1, 128]" = query_states_1.transpose(1, 2); query_states_1 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:355 in forward, code: key_states = key_states.transpose(1, 2)
key_states_2: "bf16[1, 8, 1, 128]" = key_states_1.transpose(1, 2); key_states_1 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/cache_utils.py:1999 in update, code: key_states = key_states.to(k_out.dtype)
key_states_3: "bf16[1, 8, 1, 128]" = key_states_2.to(torch.bfloat16); key_states_2 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/cache_utils.py:2000 in update, code: value_states = value_states.to(v_out.dtype)
value_states_1: "bf16[1, 8, 1, 128]" = value_states.to(torch.bfloat16); value_states = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/cache_utils.py:1946 in _sliding_update, code: self.cumulative_length[layer_idx] += key_states.shape[-2]
add_1: "Sym(s78 + 1)" = l_past_key_values_cumulative_length_0_ + 1; l_past_key_values_cumulative_length_0_ = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/cache_utils.py:1967 in _sliding_update, code: self.key_cache[layer_idx].index_copy_(2, cache_position, key_states)
index_copy_: "bf16[1, 8, 35, 128]" = k_out.index_copy_(2, l_cache_position_, key_states_3); key_states_3 = index_copy_ = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/cache_utils.py:1968 in _sliding_update, code: self.value_cache[layer_idx].index_copy_(2, cache_position, value_states)
index_copy__1: "bf16[1, 8, 35, 128]" = v_out.index_copy_(2, l_cache_position_, value_states_1); l_cache_position_ = value_states_1 = index_copy__1 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:230 in repeat_kv, code: hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
getitem_4: "bf16[1, 8, 1, 35, 128]" = k_out[(slice(None, None, None), slice(None, None, None), None, slice(None, None, None), slice(None, None, None))]; k_out = None
hidden_states_1: "bf16[1, 8, 5, 35, 128]" = getitem_4.expand(1, 8, 5, 35, 128); getitem_4 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:231 in repeat_kv, code: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
key_states_4: "bf16[1, 40, 35, 128]" = hidden_states_1.reshape(1, 40, 35, 128); hidden_states_1 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:230 in repeat_kv, code: hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
getitem_5: "bf16[1, 8, 1, 35, 128]" = v_out[(slice(None, None, None), slice(None, None, None), None, slice(None, None, None), slice(None, None, None))]; v_out = None
hidden_states_2: "bf16[1, 8, 5, 35, 128]" = getitem_5.expand(1, 8, 5, 35, 128); getitem_5 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:231 in repeat_kv, code: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
value_states_2: "bf16[1, 40, 35, 128]" = hidden_states_2.reshape(1, 40, 35, 128); hidden_states_2 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:248 in eager_attention_forward, code: attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
transpose_4: "bf16[1, 40, 128, 35]" = key_states_4.transpose(2, 3); key_states_4 = None
matmul_1: "bf16[1, 40, 1, 35]" = torch.matmul(query_states_2, transpose_4); query_states_2 = transpose_4 = None
attn_weights: "bf16[1, 40, 1, 35]" = matmul_1 * 0.08838834764831845; matmul_1 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:250 in eager_attention_forward, code: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
causal_mask: "bf16[1, 1, 1, 35]" = l_stack0_0_[(slice(None, None, None), slice(None, None, None), slice(None, None, None), slice(None, 35, None))]; l_stack0_0_ = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:251 in eager_attention_forward, code: attn_weights = attn_weights + causal_mask
attn_weights_1: "bf16[1, 40, 1, 35]" = attn_weights + causal_mask; attn_weights = causal_mask = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:253 in eager_attention_forward, code: attn_weights = nn.functional.softmax(attn_weights, dim=-1)
attn_weights_2: "bf16[1, 40, 1, 35]" = torch.nn.functional.softmax(attn_weights_1, dim = -1); attn_weights_1 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:254 in eager_attention_forward, code: attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
attn_weights_3: "bf16[1, 40, 1, 35]" = torch.nn.functional.dropout(attn_weights_2, p = 0.0, training = True); attn_weights_2 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:255 in eager_attention_forward, code: attn_output = torch.matmul(attn_weights, value_states)
attn_output: "bf16[1, 40, 1, 128]" = torch.matmul(attn_weights_3, value_states_2); attn_weights_3 = value_states_2 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:256 in eager_attention_forward, code: attn_output = attn_output.transpose(1, 2).contiguous()
transpose_5: "bf16[1, 1, 40, 128]" = attn_output.transpose(1, 2); attn_output = None
attn_output_1: "bf16[1, 1, 40, 128]" = transpose_5.contiguous(); transpose_5 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:382 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
reshape_4: "bf16[1, 1, 5120]" = attn_output_1.reshape(1, 1, -1); attn_output_1 = None
attn_output_2: "bf16[1, 1, 5120]" = reshape_4.contiguous(); reshape_4 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:383 in forward, code: attn_output = self.o_proj(attn_output)
attn_output_3: "bf16[1, 1, 5120]" = torch._C._nn.linear(attn_output_2, l_self_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_weight_, None); attn_output_2 = l_self_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_weight_ = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:437 in forward, code: hidden_states = residual + attention_states
hidden_states_3: "bf16[1, 1, 5120]" = l_inputs_embeds_ + attn_output_3; l_inputs_embeds_ = attn_output_3 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:127 in forward, code: output = self._norm(x.float()).type_as(x)
float_6: "f32[1, 1, 5120]" = hidden_states_3.float()
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:124 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
pow_2: "f32[1, 1, 5120]" = float_6.pow(2)
mean_1: "f32[1, 1, 1]" = pow_2.mean(-1, keepdim = True); pow_2 = None
add_5: "f32[1, 1, 1]" = mean_1 + 1e-05; mean_1 = None
rsqrt_1: "f32[1, 1, 1]" = torch.rsqrt(add_5); add_5 = None
mul_6: "f32[1, 1, 5120]" = float_6 * rsqrt_1; float_6 = rsqrt_1 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:127 in forward, code: output = self._norm(x.float()).type_as(x)
output_1: "bf16[1, 1, 5120]" = mul_6.type_as(hidden_states_3); mul_6 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:128 in forward, code: return output * self.weight
hidden_states_4: "bf16[1, 1, 5120]" = output_1 * l_self_modules_layers_modules_0_modules_post_attention_layernorm_parameters_weight_; output_1 = l_self_modules_layers_modules_0_modules_post_attention_layernorm_parameters_weight_ = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:95 in forward, code: down_proj = self.activation_fn(self.gate_proj(x)) * self.up_proj(x)
linear_4: "bf16[1, 1, 16384]" = torch._C._nn.linear(hidden_states_4, l_self_modules_layers_modules_0_modules_feed_forward_modules_gate_proj_parameters_weight_, None); l_self_modules_layers_modules_0_modules_feed_forward_modules_gate_proj_parameters_weight_ = None
silu: "bf16[1, 1, 16384]" = torch.nn.functional.silu(linear_4, inplace = False); linear_4 = None
linear_5: "bf16[1, 1, 16384]" = torch._C._nn.linear(hidden_states_4, l_self_modules_layers_modules_0_modules_feed_forward_modules_up_proj_parameters_weight_, None); hidden_states_4 = l_self_modules_layers_modules_0_modules_feed_forward_modules_up_proj_parameters_weight_ = None
down_proj: "bf16[1, 1, 16384]" = silu * linear_5; silu = linear_5 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:96 in forward, code: return self.down_proj(down_proj)
hidden_states_5: "bf16[1, 1, 5120]" = torch._C._nn.linear(down_proj, l_self_modules_layers_modules_0_modules_feed_forward_modules_down_proj_parameters_weight_, None); down_proj = l_self_modules_layers_modules_0_modules_feed_forward_modules_down_proj_parameters_weight_ = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:448 in forward, code: hidden_states = residual + hidden_states.view(residual.shape)
view_3: "bf16[1, 1, 5120]" = hidden_states_5.view((1, 1, 5120)); hidden_states_5 = None
hidden_states_6: "bf16[1, 1, 5120]" = hidden_states_3 + view_3; hidden_states_3 = view_3 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:127 in forward, code: output = self._norm(x.float()).type_as(x)
float_7: "f32[1, 1, 5120]" = hidden_states_6.float()
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:124 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
pow_3: "f32[1, 1, 5120]" = float_7.pow(2)
mean_2: "f32[1, 1, 1]" = pow_3.mean(-1, keepdim = True); pow_3 = None
add_7: "f32[1, 1, 1]" = mean_2 + 1e-05; mean_2 = None
rsqrt_2: "f32[1, 1, 1]" = torch.rsqrt(add_7); add_7 = None
mul_9: "f32[1, 1, 5120]" = float_7 * rsqrt_2; float_7 = rsqrt_2 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:127 in forward, code: output = self._norm(x.float()).type_as(x)
output_2: "bf16[1, 1, 5120]" = mul_9.type_as(hidden_states_6); mul_9 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:128 in forward, code: return output * self.weight
hidden_states_7: "bf16[1, 1, 5120]" = output_2 * l_self_modules_layers_modules_1_modules_input_layernorm_parameters_weight_; output_2 = l_self_modules_layers_modules_1_modules_input_layernorm_parameters_weight_ = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:333 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape)
linear_7: "bf16[1, 1, 5120]" = torch._C._nn.linear(hidden_states_7, l_self_modules_layers_modules_1_modules_self_attn_modules_q_proj_parameters_weight_, None); l_self_modules_layers_modules_1_modules_self_attn_modules_q_proj_parameters_weight_ = None
query_states_3: "bf16[1, 1, 40, 128]" = linear_7.view((1, 1, -1, 128)); linear_7 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:334 in forward, code: key_states = self.k_proj(hidden_states).view(*input_shape, -1, self.head_dim)
linear_8: "bf16[1, 1, 1024]" = torch._C._nn.linear(hidden_states_7, l_self_modules_layers_modules_1_modules_self_attn_modules_k_proj_parameters_weight_, None); l_self_modules_layers_modules_1_modules_self_attn_modules_k_proj_parameters_weight_ = None
key_states_5: "bf16[1, 1, 8, 128]" = linear_8.view(1, 1, -1, 128); linear_8 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:335 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
linear_9: "bf16[1, 1, 1024]" = torch._C._nn.linear(hidden_states_7, l_self_modules_layers_modules_1_modules_self_attn_modules_v_proj_parameters_weight_, None); hidden_states_7 = l_self_modules_layers_modules_1_modules_self_attn_modules_v_proj_parameters_weight_ = None
view_6: "bf16[1, 1, 8, 128]" = linear_9.view((1, 1, -1, 128)); linear_9 = None
value_states_3: "bf16[1, 8, 1, 128]" = view_6.transpose(1, 2); view_6 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:339 in forward, code: query_states, key_states, position_embeddings.to(query_states.device)
to_4: "c64[1, 1, 64]" = freqs_cis_1.to(device(type='cuda', index=0)); freqs_cis_1 = None
# File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:215 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
float_8: "f32[1, 1, 40, 128]" = query_states_3.float()
reshape_5: "f32[1, 1, 40, 64, 2]" = float_8.reshape(1, 1, 40, -1, 2); float_8 = None
return (reshape_5, key_states_5, to_4, query_states_3, value_states_3, hidden_states_6, add_1)

Note that l_past_key_values_cumulative_length_0_ is a SymInt and is expected to change every iteration. But with the default cache setting for thunder, recompilation is triggered every time this value changes.
Expected behaviour: we should be able to run benchmark_inference with symbolic values without triggering recompilation on every iteration.
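For context, the cache mode that the patch below forwards corresponds to the cache argument of thunder.jit. A minimal sketch of the intended behaviour (assuming thunder.jit accepts the cache mode as a string, which the patch relies on; fn and the inputs are illustrative):

```python
import torch
import thunder

def fn(x, n):
    # n stands in for the SymInt-like value (the cumulative cache length)
    return x + n

# With the default cache mode, a change in n triggers a new compilation.
# With "symbolic values", n should be traced as a symbolic number instead.
jfn = thunder.jit(fn, cache="symbolic values")

x = torch.randn(2, 2)
jfn(x, 3)
jfn(x, 4)  # should reuse the existing trace rather than recompile
```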
Patch to pass the cache mode (e.g. "symbolic values") as a jit option
diff --git a/thunder/benchmarks/benchmark_inference.py b/thunder/benchmarks/benchmark_inference.py
index ec4cab56..36e0d311 100644
--- a/thunder/benchmarks/benchmark_inference.py
+++ b/thunder/benchmarks/benchmark_inference.py
@@ -154,7 +154,7 @@ class InferenceBenchmarkConfig:
mode: str
disable_moe_replacement: bool
profile: bool
-
+ cache: str | None
@dataclass
class InferenceMetrics:
@@ -287,6 +287,8 @@ class InferenceBenchmark:
self._mask_transform = SDPAMaskTransform()
res["transforms"] = [self._mask_transform]
res["executors"] = [self._mask_transform.get_executor(), *thunder.get_default_executors()]
+ if self.config.cache:
+ res["cache"] = self.config.cache
return res
def _compile_model(self, model):
@@ -677,6 +679,7 @@ Examples:
parser.add_argument("--save-results", action="store_true", help="Save results to JSON file")
parser.add_argument("--output-dir", type=str, default="./results", help="Directory to save results")
+ parser.add_argument("--cache", type=str, default=None, help="Cache option: no caching, same input, constant values, symbolic values")
args = parser.parse_args()
return args
@@ -710,6 +713,7 @@ def main():
enable_nv_linear=args.enable_nv_linear,
disable_moe_replacement=args.disable_moe_replacement,
profile=args.profile,
+ cache=args.cache,
)
benchmark = InferenceBenchmark(config)

Currently, running with symbolic values (python thunder/benchmarks/benchmark_inference.py --input-length 32 --output-length 3 --mode thunder --num-iterations 10 --cache "symbolic values") fails with:
File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 352, in translate_bound_symbol
nvresults = translator(*bsym.args, **bsym.kwargs, fd=fd, lc_to_nv_map=lc_to_nv_map)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 1046, in full
return fd.ops.full(shape, nv_fill_value, nvdtype)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
0%| | 0/10 [00:10<?, ?it/s]
Exception occurred: RuntimeError: Expected false . Unsupported iterable object type for define_vector! Index:0
Exception raised from define_vector_fn at /opt/pytorch/nvfuser/python/python_direct/ops.cpp:2391 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, long, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xfb (0x766f82e110dd in /opt/pytorch/nvfuser/python/nvfuser_direct/../build/libnvfuser_codegen.so)
frame #1: <unknown function> + 0xee003 (0x766f85a97003 in /opt/pytorch/nvfuser/python/nvfuser_direct/_C_DIRECT.cpython-312-x86_64-linux-gnu.so)
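From the trace, the failure seems to come from the nvfuser full translator: with the symbolic-values cache, the shape passed to fd.ops.full contains a symbolic extent, which define_vector rejects. A minimal sketch of a standalone repro (assumed, not verified; the op in the benchmark that actually hits this path may differ, e.g. the SDPA mask construction):

```python
import torch
import thunder

def make_mask(x, length):
    # length plays the role of the symbolic cumulative cache length
    return x + torch.full((1, 1, 1, length), 0.0, device=x.device, dtype=x.dtype)

jfn = thunder.jit(make_mask, cache="symbolic values")

x = torch.randn(1, 1, 1, 1, device="cuda", dtype=torch.bfloat16)
jfn(x, 35)  # expected to reach fd.ops.full with a symbolic shape entry
```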
cc: @beverlylytle
cc @lantiga