
Symbolic Values support for benchmark_inference.py #2677

Description

@kshitij12345

While running python thunder/benchmarks/benchmark_inference.py --input-length 32 --output-length 3 --mode thunder --num-iterations 10, one of the FX graphs we receive is the following.

FXGraph

class GraphModule(torch.nn.Module):
    def forward(self, view_as_real_1: "f32[1, 1, 8, 64, 2]", xq_out: "f32[1, 1, 40, 128]", query_states: "bf16[1, 1, 40, 128]", key_states: "bf16[1, 1, 8, 128]", value_states: "bf16[1, 8, 1, 128]", l_past_key_values_cumulative_length_0_: "Sym(s78)", k_out: "bf16[1, 8, 35, 128]", l_cache_position_: "i64[1]", v_out: "bf16[1, 8, 35, 128]", l_stack0_0_: "bf16[1, 1, 1, 8192]", l_self_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_weight_: "bf16[5120, 5120]", l_inputs_embeds_: "bf16[1, 1, 5120]", l_self_modules_layers_modules_0_modules_post_attention_layernorm_parameters_weight_: "bf16[5120]", l_self_modules_layers_modules_0_modules_feed_forward_modules_gate_proj_parameters_weight_: "bf16[16384, 5120]", l_self_modules_layers_modules_0_modules_feed_forward_modules_up_proj_parameters_weight_: "bf16[16384, 5120]", l_self_modules_layers_modules_0_modules_feed_forward_modules_down_proj_parameters_weight_: "bf16[5120, 16384]", l_self_modules_layers_modules_1_modules_input_layernorm_parameters_weight_: "bf16[5120]", l_self_modules_layers_modules_1_modules_self_attn_modules_q_proj_parameters_weight_: "bf16[5120, 5120]", l_self_modules_layers_modules_1_modules_self_attn_modules_k_proj_parameters_weight_: "bf16[1024, 5120]", l_self_modules_layers_modules_1_modules_self_attn_modules_v_proj_parameters_weight_: "bf16[1024, 5120]", freqs_cis_1: "c64[1, 1, 64]"):
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:218 in apply_rotary_emb, code: xk_out = torch.view_as_real(xk_ * freqs_cis[:, :, None, :]).flatten(3)
        xk_out: "f32[1, 1, 8, 128]" = view_as_real_1.flatten(3);  view_as_real_1 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:219 in apply_rotary_emb, code: return xq_out.type_as(xq), xk_out.type_as(xk)
        query_states_1: "bf16[1, 1, 40, 128]" = xq_out.type_as(query_states);  xq_out = query_states = None
        key_states_1: "bf16[1, 1, 8, 128]" = xk_out.type_as(key_states);  xk_out = key_states = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:354 in forward, code: query_states = query_states.transpose(1, 2)
        query_states_2: "bf16[1, 40, 1, 128]" = query_states_1.transpose(1, 2);  query_states_1 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:355 in forward, code: key_states = key_states.transpose(1, 2)
        key_states_2: "bf16[1, 8, 1, 128]" = key_states_1.transpose(1, 2);  key_states_1 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/cache_utils.py:1999 in update, code: key_states = key_states.to(k_out.dtype)
        key_states_3: "bf16[1, 8, 1, 128]" = key_states_2.to(torch.bfloat16);  key_states_2 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/cache_utils.py:2000 in update, code: value_states = value_states.to(v_out.dtype)
        value_states_1: "bf16[1, 8, 1, 128]" = value_states.to(torch.bfloat16);  value_states = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/cache_utils.py:1946 in _sliding_update, code: self.cumulative_length[layer_idx] += key_states.shape[-2]
        add_1: "Sym(s78 + 1)" = l_past_key_values_cumulative_length_0_ + 1;  l_past_key_values_cumulative_length_0_ = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/cache_utils.py:1967 in _sliding_update, code: self.key_cache[layer_idx].index_copy_(2, cache_position, key_states)
        index_copy_: "bf16[1, 8, 35, 128]" = k_out.index_copy_(2, l_cache_position_, key_states_3);  key_states_3 = index_copy_ = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/cache_utils.py:1968 in _sliding_update, code: self.value_cache[layer_idx].index_copy_(2, cache_position, value_states)
        index_copy__1: "bf16[1, 8, 35, 128]" = v_out.index_copy_(2, l_cache_position_, value_states_1);  l_cache_position_ = value_states_1 = index_copy__1 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:230 in repeat_kv, code: hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
        getitem_4: "bf16[1, 8, 1, 35, 128]" = k_out[(slice(None, None, None), slice(None, None, None), None, slice(None, None, None), slice(None, None, None))];  k_out = None
        hidden_states_1: "bf16[1, 8, 5, 35, 128]" = getitem_4.expand(1, 8, 5, 35, 128);  getitem_4 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:231 in repeat_kv, code: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
        key_states_4: "bf16[1, 40, 35, 128]" = hidden_states_1.reshape(1, 40, 35, 128);  hidden_states_1 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:230 in repeat_kv, code: hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
        getitem_5: "bf16[1, 8, 1, 35, 128]" = v_out[(slice(None, None, None), slice(None, None, None), None, slice(None, None, None), slice(None, None, None))];  v_out = None
        hidden_states_2: "bf16[1, 8, 5, 35, 128]" = getitem_5.expand(1, 8, 5, 35, 128);  getitem_5 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:231 in repeat_kv, code: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
        value_states_2: "bf16[1, 40, 35, 128]" = hidden_states_2.reshape(1, 40, 35, 128);  hidden_states_2 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:248 in eager_attention_forward, code: attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
        transpose_4: "bf16[1, 40, 128, 35]" = key_states_4.transpose(2, 3);  key_states_4 = None
        matmul_1: "bf16[1, 40, 1, 35]" = torch.matmul(query_states_2, transpose_4);  query_states_2 = transpose_4 = None
        attn_weights: "bf16[1, 40, 1, 35]" = matmul_1 * 0.08838834764831845;  matmul_1 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:250 in eager_attention_forward, code: causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        causal_mask: "bf16[1, 1, 1, 35]" = l_stack0_0_[(slice(None, None, None), slice(None, None, None), slice(None, None, None), slice(None, 35, None))];  l_stack0_0_ = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:251 in eager_attention_forward, code: attn_weights = attn_weights + causal_mask
        attn_weights_1: "bf16[1, 40, 1, 35]" = attn_weights + causal_mask;  attn_weights = causal_mask = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:253 in eager_attention_forward, code: attn_weights = nn.functional.softmax(attn_weights, dim=-1)
        attn_weights_2: "bf16[1, 40, 1, 35]" = torch.nn.functional.softmax(attn_weights_1, dim = -1);  attn_weights_1 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:254 in eager_attention_forward, code: attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
        attn_weights_3: "bf16[1, 40, 1, 35]" = torch.nn.functional.dropout(attn_weights_2, p = 0.0, training = True);  attn_weights_2 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:255 in eager_attention_forward, code: attn_output = torch.matmul(attn_weights, value_states)
        attn_output: "bf16[1, 40, 1, 128]" = torch.matmul(attn_weights_3, value_states_2);  attn_weights_3 = value_states_2 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:256 in eager_attention_forward, code: attn_output = attn_output.transpose(1, 2).contiguous()
        transpose_5: "bf16[1, 1, 40, 128]" = attn_output.transpose(1, 2);  attn_output = None
        attn_output_1: "bf16[1, 1, 40, 128]" = transpose_5.contiguous();  transpose_5 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:382 in forward, code: attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        reshape_4: "bf16[1, 1, 5120]" = attn_output_1.reshape(1, 1, -1);  attn_output_1 = None
        attn_output_2: "bf16[1, 1, 5120]" = reshape_4.contiguous();  reshape_4 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:383 in forward, code: attn_output = self.o_proj(attn_output)
        attn_output_3: "bf16[1, 1, 5120]" = torch._C._nn.linear(attn_output_2, l_self_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_weight_, None);  attn_output_2 = l_self_modules_layers_modules_0_modules_self_attn_modules_o_proj_parameters_weight_ = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:437 in forward, code: hidden_states = residual + attention_states
        hidden_states_3: "bf16[1, 1, 5120]" = l_inputs_embeds_ + attn_output_3;  l_inputs_embeds_ = attn_output_3 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:127 in forward, code: output = self._norm(x.float()).type_as(x)
        float_6: "f32[1, 1, 5120]" = hidden_states_3.float()
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:124 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        pow_2: "f32[1, 1, 5120]" = float_6.pow(2)
        mean_1: "f32[1, 1, 1]" = pow_2.mean(-1, keepdim = True);  pow_2 = None
        add_5: "f32[1, 1, 1]" = mean_1 + 1e-05;  mean_1 = None
        rsqrt_1: "f32[1, 1, 1]" = torch.rsqrt(add_5);  add_5 = None
        mul_6: "f32[1, 1, 5120]" = float_6 * rsqrt_1;  float_6 = rsqrt_1 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:127 in forward, code: output = self._norm(x.float()).type_as(x)
        output_1: "bf16[1, 1, 5120]" = mul_6.type_as(hidden_states_3);  mul_6 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:128 in forward, code: return output * self.weight
        hidden_states_4: "bf16[1, 1, 5120]" = output_1 * l_self_modules_layers_modules_0_modules_post_attention_layernorm_parameters_weight_;  output_1 = l_self_modules_layers_modules_0_modules_post_attention_layernorm_parameters_weight_ = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:95 in forward, code: down_proj = self.activation_fn(self.gate_proj(x)) * self.up_proj(x)
        linear_4: "bf16[1, 1, 16384]" = torch._C._nn.linear(hidden_states_4, l_self_modules_layers_modules_0_modules_feed_forward_modules_gate_proj_parameters_weight_, None);  l_self_modules_layers_modules_0_modules_feed_forward_modules_gate_proj_parameters_weight_ = None
        silu: "bf16[1, 1, 16384]" = torch.nn.functional.silu(linear_4, inplace = False);  linear_4 = None
        linear_5: "bf16[1, 1, 16384]" = torch._C._nn.linear(hidden_states_4, l_self_modules_layers_modules_0_modules_feed_forward_modules_up_proj_parameters_weight_, None);  hidden_states_4 = l_self_modules_layers_modules_0_modules_feed_forward_modules_up_proj_parameters_weight_ = None
        down_proj: "bf16[1, 1, 16384]" = silu * linear_5;  silu = linear_5 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:96 in forward, code: return self.down_proj(down_proj)
        hidden_states_5: "bf16[1, 1, 5120]" = torch._C._nn.linear(down_proj, l_self_modules_layers_modules_0_modules_feed_forward_modules_down_proj_parameters_weight_, None);  down_proj = l_self_modules_layers_modules_0_modules_feed_forward_modules_down_proj_parameters_weight_ = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:448 in forward, code: hidden_states = residual + hidden_states.view(residual.shape)
        view_3: "bf16[1, 1, 5120]" = hidden_states_5.view((1, 1, 5120));  hidden_states_5 = None
        hidden_states_6: "bf16[1, 1, 5120]" = hidden_states_3 + view_3;  hidden_states_3 = view_3 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:127 in forward, code: output = self._norm(x.float()).type_as(x)
        float_7: "f32[1, 1, 5120]" = hidden_states_6.float()
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:124 in _norm, code: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        pow_3: "f32[1, 1, 5120]" = float_7.pow(2)
        mean_2: "f32[1, 1, 1]" = pow_3.mean(-1, keepdim = True);  pow_3 = None
        add_7: "f32[1, 1, 1]" = mean_2 + 1e-05;  mean_2 = None
        rsqrt_2: "f32[1, 1, 1]" = torch.rsqrt(add_7);  add_7 = None
        mul_9: "f32[1, 1, 5120]" = float_7 * rsqrt_2;  float_7 = rsqrt_2 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:127 in forward, code: output = self._norm(x.float()).type_as(x)
        output_2: "bf16[1, 1, 5120]" = mul_9.type_as(hidden_states_6);  mul_9 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:128 in forward, code: return output * self.weight
        hidden_states_7: "bf16[1, 1, 5120]" = output_2 * l_self_modules_layers_modules_1_modules_input_layernorm_parameters_weight_;  output_2 = l_self_modules_layers_modules_1_modules_input_layernorm_parameters_weight_ = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:333 in forward, code: query_states = self.q_proj(hidden_states).view(hidden_shape)
        linear_7: "bf16[1, 1, 5120]" = torch._C._nn.linear(hidden_states_7, l_self_modules_layers_modules_1_modules_self_attn_modules_q_proj_parameters_weight_, None);  l_self_modules_layers_modules_1_modules_self_attn_modules_q_proj_parameters_weight_ = None
        query_states_3: "bf16[1, 1, 40, 128]" = linear_7.view((1, 1, -1, 128));  linear_7 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:334 in forward, code: key_states = self.k_proj(hidden_states).view(*input_shape, -1, self.head_dim)
        linear_8: "bf16[1, 1, 1024]" = torch._C._nn.linear(hidden_states_7, l_self_modules_layers_modules_1_modules_self_attn_modules_k_proj_parameters_weight_, None);  l_self_modules_layers_modules_1_modules_self_attn_modules_k_proj_parameters_weight_ = None
        key_states_5: "bf16[1, 1, 8, 128]" = linear_8.view(1, 1, -1, 128);  linear_8 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:335 in forward, code: value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        linear_9: "bf16[1, 1, 1024]" = torch._C._nn.linear(hidden_states_7, l_self_modules_layers_modules_1_modules_self_attn_modules_v_proj_parameters_weight_, None);  hidden_states_7 = l_self_modules_layers_modules_1_modules_self_attn_modules_v_proj_parameters_weight_ = None
        view_6: "bf16[1, 1, 8, 128]" = linear_9.view((1, 1, -1, 128));  linear_9 = None
        value_states_3: "bf16[1, 8, 1, 128]" = view_6.transpose(1, 2);  view_6 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:339 in forward, code: query_states, key_states, position_embeddings.to(query_states.device)
        to_4: "c64[1, 1, 64]" = freqs_cis_1.to(device(type='cuda', index=0));  freqs_cis_1 = None
        
         # File: /usr/local/lib/python3.12/dist-packages/transformers/models/llama4/modeling_llama4.py:215 in apply_rotary_emb, code: xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
        float_8: "f32[1, 1, 40, 128]" = query_states_3.float()
        reshape_5: "f32[1, 1, 40, 64, 2]" = float_8.reshape(1, 1, 40, -1, 2);  float_8 = None
        return (reshape_5, key_states_5, to_4, query_states_3, value_states_3, hidden_states_6, add_1)

Note that l_past_key_values_cumulative_length_0_ is a SymInt and is expected to change every iteration. With thunder's default cache setting, however, recompilation is triggered every time this value changes.

Expected behaviour: we should be able to run benchmark_inference.py with symbolic values enabled without triggering a recompilation on every iteration.
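
For context, the behaviour we want maps directly onto thunder.jit's cache option. Below is a minimal sketch of the intended usage, assuming a CUDA device; the toy function f is illustrative and not taken from the benchmark.

import torch
import thunder

def f(x, cumulative_length):
    # cumulative_length plays the role of the SymInt above: it changes on every call.
    return x * 2, cumulative_length + 1

# With the default "constant values" cache, every new value of cumulative_length
# triggers a retrace; with "symbolic values" one trace should be reused.
jf = thunder.jit(f, cache="symbolic values")

x = torch.randn(8, device="cuda")
for cumulative_length in range(10):
    jf(x, cumulative_length)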

Patch to pass symbolic values as a jit option

diff --git a/thunder/benchmarks/benchmark_inference.py b/thunder/benchmarks/benchmark_inference.py
index ec4cab56..36e0d311 100644
--- a/thunder/benchmarks/benchmark_inference.py
+++ b/thunder/benchmarks/benchmark_inference.py
@@ -154,7 +154,7 @@ class InferenceBenchmarkConfig:
     mode: str
     disable_moe_replacement: bool
     profile: bool
-
+    cache: str | None
 
 @dataclass
 class InferenceMetrics:
@@ -287,6 +287,8 @@ class InferenceBenchmark:
                 self._mask_transform = SDPAMaskTransform()
             res["transforms"] = [self._mask_transform]
             res["executors"] = [self._mask_transform.get_executor(), *thunder.get_default_executors()]
+        if self.config.cache:
+            res["cache"] = self.config.cache
         return res
 
     def _compile_model(self, model):
@@ -677,6 +679,7 @@ Examples:
 
     parser.add_argument("--save-results", action="store_true", help="Save results to JSON file")
     parser.add_argument("--output-dir", type=str, default="./results", help="Directory to save results")
+    parser.add_argument("--cache", type=str, default=None, help="Cache option: no caching, same input, constant values, symbolic values")
 
     args = parser.parse_args()
     return args
@@ -710,6 +713,7 @@ def main():
         enable_nv_linear=args.enable_nv_linear,
         disable_moe_replacement=args.disable_moe_replacement,
         profile=args.profile,
+        cache=args.cache,
     )
     benchmark = InferenceBenchmark(config)

Currently, running with symbolic values (python thunder/benchmarks/benchmark_inference.py --input-length 32 --output-length 3 --mode thunder --num-iterations 10 --cache "symbolic values") fails with:

  File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 352, in translate_bound_symbol
    nvresults = translator(*bsym.args, **bsym.kwargs, fd=fd, lc_to_nv_map=lc_to_nv_map)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/pytorch/lightning-thunder/thunder/executors/nvfuserex_impl.py", line 1046, in full
    return fd.ops.full(shape, nv_fill_value, nvdtype)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

  0%|          | 0/10 [00:10<?, ?it/s]
Exception occurred: RuntimeError: Expected false . Unsupported iterable object type for define_vector! Index:0
Exception raised from define_vector_fn at /opt/pytorch/nvfuser/python/python_direct/ops.cpp:2391 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, long, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xfb (0x766f82e110dd in /opt/pytorch/nvfuser/python/nvfuser_direct/../build/libnvfuser_codegen.so)
frame #1: <unknown function> + 0xee003 (0x766f85a97003 in /opt/pytorch/nvfuser/python/nvfuser_direct/_C_DIRECT.cpython-312-x86_64-linux-gnu.so)
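
The failure seems to come from the nvFuser full translation receiving a shape that contains a symbolic value (a proxy) rather than concrete integers, which the define_vector path rejects. The following is a hypothetical standalone sketch in that direction, not extracted from the benchmark; whether it reaches exactly the same path is an assumption.

import torch
import thunder

def f(x):
    # The fill shape depends on an input dimension, so under cache="symbolic values"
    # it is traced as a symbolic value rather than a compile-time constant.
    return x + torch.full((x.shape[0],), 1.0, device=x.device, dtype=x.dtype)

jf = thunder.jit(f, cache="symbolic values")
jf(torch.randn(4, device="cuda"))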

cc: @beverlylytle

cc @lantiga
