diff --git a/vllm_ascend/ops/triton/mamba/casual_conv1d.py b/vllm_ascend/ops/triton/mamba/casual_conv1d.py
index bb8299237b3..be4c6053e9d 100644
--- a/vllm_ascend/ops/triton/mamba/casual_conv1d.py
+++ b/vllm_ascend/ops/triton/mamba/casual_conv1d.py
@@ -17,7 +17,7 @@
 PAD_SLOT_ID = -1
 
 
-def causal_conv1d_ref(
+def causal_conv1d_fn_native(
     x: torch.Tensor,
     weight: torch.Tensor,
     bias: Optional[torch.Tensor] = None,
@@ -32,6 +32,7 @@ def causal_conv1d_ref(
     bias: (dim,)
     initial_states: (batch, dim, width - 1)
     final_states_out: (batch, dim, width - 1)
+    out: (batch, dim, seqlen)
     """
 
     if activation not in [None, "silu", "swish"]:
@@ -42,11 +43,7 @@
     seqlen = x.shape[-1]
     dim, width = weight.shape
     if initial_states is None:
-        out = F.conv1d(x,
-                       weight.unsqueeze(1),
-                       bias,
-                       padding=width - 1,
-                       groups=dim)
+        out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
     else:
         x = torch.cat([initial_states, x], dim=-1)
         out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
@@ -55,7 +52,7 @@
         final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
             dtype_in)  # (batch, dim, width - 1)
         if final_states_out is not None:
-            final_states_out[..., :(width - 1)].copy_(final_states)
+            final_states_out.copy_(final_states)
         else:
             final_states_out = final_states
     out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
@@ -109,24 +106,18 @@ def causal_conv1d_fn(
     out_ref = []
     out_ref_b = []
 
-    seqlens = query_start_loc[1:] - query_start_loc[:-1]
-    seqlens = seqlens.tolist()
-    splits = torch.split(x, seqlens, dim=-1)
-
-    for i in range(len(seqlens)):
-        x_s = splits[i]
-        if cache_indices[i] == PAD_SLOT_ID:
-            continue
-        out_ref_b.append(
-            causal_conv1d_ref(
-                x_s,
-                weight,
-                bias,
-                activation=activation,
-                return_final_states=True,
-                final_states_out=conv_states[cache_indices[i]].unsqueeze(0),
-                initial_states=conv_states[cache_indices[i]]
-                if has_initial_state[i] else None))
+    out_ref_b.append(
+        causal_conv1d_fn_native(
+            x,
+            weight,
+            bias,
+            activation=activation,
+            return_final_states=True,
+            final_states_out=conv_states[cache_indices[0]].unsqueeze(0),
+            initial_states=(conv_states[cache_indices[0]].unsqueeze(0)
+                            if has_initial_state[0] else None),
+        ))
+
     out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
     out_ref_tensor = torch.cat(out_ref, dim=0)
     return out_ref_tensor
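
For reference, below is a minimal sketch of what the renamed `causal_conv1d_fn_native` computes on the no-initial-state path: a depthwise causal conv built from `F.conv1d`. The helper name and the trimmed-down signature are illustrative only; the real function in the diff also handles `initial_states`, `final_states_out`, and `return_final_states`.

```python
import torch
import torch.nn.functional as F


def causal_conv1d_sketch(x, weight, bias=None, activation=None):
    # Depthwise causal conv: groups=dim gives each channel its own
    # width-tap filter; padding=width - 1 makes every output position
    # see only current and past inputs, and trimming back to seqlen
    # drops the non-causal right overhang that the padding introduces.
    dtype_in = x.dtype
    seqlen = x.shape[-1]
    dim, width = weight.shape
    out = F.conv1d(x.to(weight.dtype), weight.unsqueeze(1), bias,
                   padding=width - 1, groups=dim)[..., :seqlen]
    return (out if activation is None else F.silu(out)).to(dtype_in)


# Illustrative shapes matching the docstring added in the diff:
# x (batch, dim, seqlen), weight (dim, width), out (batch, dim, seqlen).
x = torch.randn(2, 8, 16)
weight = torch.randn(8, 4)
out = causal_conv1d_sketch(x, weight, activation="silu")
assert out.shape == x.shape
```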