Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 22 additions & 26 deletions vllm_ascend/ops/triton/mamba/casual_conv1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
PAD_SLOT_ID = -1


def causal_conv1d_ref(
def causal_conv1d_fn_native(
x: torch.Tensor,
weight: torch.Tensor,
bias: Optional[torch.Tensor] = None,
Expand All @@ -32,6 +32,7 @@ def causal_conv1d_ref(
bias: (dim,)
initial_states: (batch, dim, width - 1)
final_states_out: (batch, dim, width - 1)

out: (batch, dim, seqlen)
"""
if activation not in [None, "silu", "swish"]:
Expand All @@ -42,20 +43,17 @@ def causal_conv1d_ref(
dim, width = weight.shape

if initial_states is None:
out = F.conv1d(x,
weight.unsqueeze(1),
bias,
padding=width - 1,
groups=dim)
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
else:
x = torch.cat([initial_states, x], dim=-1)
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
out = out[..., :seqlen]
if return_final_states:
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
dtype_in) # (batch, dim, width - 1)
dtype_in
) # (batch, dim, width - 1)
if final_states_out is not None:
final_states_out[..., :(width - 1)].copy_(final_states)
final_states_out.copy_(final_states)
else:
final_states_out = final_states
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
Expand Down Expand Up @@ -109,24 +107,22 @@ def causal_conv1d_fn(

out_ref = []
out_ref_b = []
seqlens = query_start_loc[1:] - query_start_loc[:-1]
seqlens = seqlens.tolist()
splits = torch.split(x, seqlens, dim=-1)

for i in range(len(seqlens)):
x_s = splits[i]
if cache_indices[i] == PAD_SLOT_ID:
continue
out_ref_b.append(
causal_conv1d_ref(
x_s,
weight,
bias,
activation=activation,
return_final_states=True,
final_states_out=conv_states[cache_indices[i]].unsqueeze(0),
initial_states=conv_states[cache_indices[i]]
if has_initial_state[i] else None))
out_ref_b.append(
causal_conv1d_fn_native(
x,
weight,
bias,
activation=activation,
return_final_states=True,
final_states_out=conv_states[cache_indices[0]].unsqueeze(0),
initial_states=(
conv_states[cache_indices[0]].unsqueeze(0)
if has_initial_state[0]
else None
),
)
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The refactoring of this function to improve performance has unfortunately broken the batch processing capability. The original implementation iterated over sequences in a batch, which is necessary for handling variable-length sequences packed together. The new implementation removes this loop and only processes the first element of the batch (as evidenced by the use of cache_indices[0] and has_initial_state[0]). This will result in incorrect behavior for any input with a batch size greater than one.

While this change fixes a shape mismatch for initial_states by adding .unsqueeze(0), removing the loop is incorrect. The correct approach is to restore the loop for batch processing and apply the unsqueeze(0) fix within the loop. Here is a suggested implementation:

    seqlens = query_start_loc[1:] - query_start_loc[:-1]
    seqlens = seqlens.tolist()
    splits = torch.split(x, seqlens, dim=-1)

    for i in range(len(seqlens)):
        x_s = splits[i]
        if cache_indices[i] == PAD_SLOT_ID:
            continue
        out_ref_b.append(
            causal_conv1d_fn_native(
                x_s,
                weight,
                bias,
                activation=activation,
                return_final_states=True,
                final_states_out=conv_states[cache_indices[i]].unsqueeze(0),
                initial_states=(
                    conv_states[cache_indices[i]].unsqueeze(0)
                    if has_initial_state[i]
                    else None
                ),
            )
        )


out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=-1))
out_ref_tensor = torch.cat(out_ref, dim=0)
return out_ref_tensor
Expand Down
Loading