69 changes: 68 additions & 1 deletion csrc/xpu/attention_xpu.cpp
@@ -1254,4 +1254,71 @@ void paged_attention_v2(
query.scalar_type(), "paged_attention_xpu_v2_impl", [&] {
CALL_V2_LAUNCHER_BLOCK_SIZE(scalar_t);
});
}
}


constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

void advance_step_ipex(int num_seqs, int num_queries, int block_size,
                       torch::Tensor& input_tokens,      // type: long
                       torch::Tensor& sampled_token_ids, // type: long
                       torch::Tensor& input_positions,   // type: long
                       torch::Tensor& seq_lens,          // type: int
                       torch::Tensor& slot_mapping,      // type: long
                       torch::Tensor& block_tables) {    // type: int
  sycl::queue& queue = vllm::xpu::vllmGetQueue();
  // TODO: we might want to tune these launch parameters.
  int num_blocks = 1024;
  int num_threads = 32;

  int64_t* input_tokens_ptr =
      reinterpret_cast<int64_t*>(input_tokens.data_ptr());
  int64_t const* sampled_token_ids_ptr =
      reinterpret_cast<int64_t const*>(sampled_token_ids.data_ptr());
  int64_t* input_positions_ptr =
      reinterpret_cast<int64_t*>(input_positions.data_ptr());
  int* seq_lens_ptr = reinterpret_cast<int*>(seq_lens.data_ptr());
  int64_t* slot_mapping_ptr =
      reinterpret_cast<int64_t*>(slot_mapping.data_ptr());
  int const* block_tables_ptr =
      reinterpret_cast<int const*>(block_tables.data_ptr());
  int64_t const block_tables_stride = block_tables.stride(0);

  sycl::range<1> grid(num_blocks);
  sycl::range<1> block(num_threads);
  queue.submit([&](sycl::handler& cgh) {
    cgh.parallel_for(
        sycl::nd_range<1>(grid * block, block),
        [=](sycl::nd_item<1> item_ct1) {
          int num_query_blocks = div_ceil(num_queries, num_threads);
          int group = item_ct1.get_group(0);
          if (group >= num_query_blocks) {
            return;
          }

          int cur_query_id = group * num_threads + item_ct1.get_local_id(0);
          if (cur_query_id >= num_queries) {
            return;
          }

          // Append the token sampled in the previous step.
          input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];

          int seq_len = seq_lens_ptr[cur_query_id];
          int next_seq_len = seq_len + 1;
          int next_input_pos = next_seq_len - 1;

          // Update seq_lens
          seq_lens_ptr[cur_query_id] = next_seq_len;
          // Update input_positions
          input_positions_ptr[cur_query_id] = next_input_pos;

          // Translate the new position into a physical KV-cache slot via the
          // per-sequence block table.
          int const* seq_block_tables_ptr =
              block_tables_ptr + block_tables_stride * cur_query_id;
          int block_index = next_input_pos / block_size;
          int block_offset = next_input_pos % block_size;
          int slot_num =
              seq_block_tables_ptr[block_index] * block_size + block_offset;
          // Update slot_mapping
          slot_mapping_ptr[cur_query_id] = slot_num;
        });
  });
}
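
For reference, the per-query update performed by this kernel corresponds to the host-side sketch below. This is a minimal illustration only; the function name and plain-Python containers are assumptions, not part of this PR.

# Minimal host-side sketch of what advance_step_ipex does per query.
# Names and plain-Python containers are illustrative, not part of the PR.
def advance_step_reference(num_queries, block_size,
                           input_tokens, sampled_token_ids, input_positions,
                           seq_lens, slot_mapping, block_tables):
    for q in range(num_queries):
        # Append the token sampled in the previous step.
        input_tokens[q] = sampled_token_ids[q]
        # The sequence grows by one; the new token's position is seq_len - 1.
        next_seq_len = seq_lens[q] + 1
        next_input_pos = next_seq_len - 1
        seq_lens[q] = next_seq_len
        input_positions[q] = next_input_pos
        # Map the new position to a physical KV-cache slot via the block table.
        block_index = next_input_pos // block_size
        block_offset = next_input_pos % block_size
        slot_mapping[q] = (block_tables[q][block_index] * block_size
                           + block_offset)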
6 changes: 6 additions & 0 deletions csrc/xpu/pybind.cpp
@@ -69,6 +69,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"reshape_and_cache",
&reshape_and_cache,
"Reshape the key and value tensors and cache them");

ops.def(
"advance_step_ipex",
&advance_step_ipex,
"Advance-step function used by the multi-step scheduler");

// Quant
ops.def(
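
With this binding registered, the op should be callable from Python roughly as sketched below. This is a hedged usage sketch; the vllm._C.ops module path is taken from the commented-out call in xpu_multi_step_model_runner.py and may differ per build, and the tensor shapes are placeholders.

# Hedged usage sketch; dtypes follow the comments on the C++ signature.
import torch
import vllm._C.ops as ops  # module path assumption; see the runner's commented-out call

device = "xpu"
num_seqs = num_queries = 4
block_size = 16
max_blocks_per_seq = 8

input_tokens = torch.zeros(num_seqs, dtype=torch.long, device=device)
sampled_token_ids = torch.ones(num_seqs, dtype=torch.long, device=device)
input_positions = torch.zeros(num_seqs, dtype=torch.long, device=device)
seq_lens = torch.ones(num_seqs, dtype=torch.int32, device=device)
slot_mapping = torch.zeros(num_seqs, dtype=torch.long, device=device)
block_tables = torch.zeros((num_seqs, max_blocks_per_seq),
                           dtype=torch.int32, device=device)

# All tensors are updated in place on the device.
ops.advance_step_ipex(num_seqs, num_queries, block_size,
                      input_tokens, sampled_token_ids, input_positions,
                      seq_lens, slot_mapping, block_tables)
torch.xpu.synchronize()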
9 changes: 9 additions & 0 deletions csrc/xpu/xpu_ops.h
@@ -93,6 +93,15 @@ torch::Tensor marlin_gemm(
TORCH_CHECK(false, "marlin_gemm is not supported on XPU.");
}


void advance_step_ipex(int num_seqs, int num_queries, int block_size,
torch::Tensor& input_tokens, // type: long
torch::Tensor& sampled_token_ids, // type: long
torch::Tensor& input_positions, // type: long
torch::Tensor& seq_lens, // type: int
torch::Tensor& slot_mapping, // type: long
torch::Tensor& block_tables);

torch::Tensor awq_dequantize(torch::Tensor _kernel,
torch::Tensor _scaling_factors,
torch::Tensor _zeros,
1 change: 0 additions & 1 deletion vllm/sequence.py
@@ -692,7 +692,6 @@ def __init__(
self.encoder_seq = encoder_seq
self.trace_headers = trace_headers
self.priority = priority

self.cached_request_output = None

@property
24 changes: 23 additions & 1 deletion vllm/worker/xpu_multi_step_model_runner.py
@@ -472,9 +472,25 @@ def _advance_step(self, model_input: XPUStatefulModelInput,

attn_metadata = frozen_model_input.attn_metadata
assert isinstance(attn_metadata, IpexAttnMetadata)
# Add one to self.seq_lens
attn_metadata.advance_step(num_seqs, num_queries)
sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids

# --------------------- Fused-kernel path (disabled) ---------------------
# Option A: call the XPU kernel directly on the step tensors, replacing the
# Python implementation below.
# import vllm._C.ops
# vllm._C.ops.advance_step_ipex(
#     num_seqs, num_queries, self.block_size,
#     frozen_model_input.input_tokens, sampled_token_ids,
#     frozen_model_input.input_positions, attn_metadata.seq_lens_tensor,
#     attn_metadata.slot_mapping, attn_metadata.block_tables)
# torch.xpu.synchronize()
#
# Option B: run the kernel on clones and compare against the Python path
# (see the commented-out asserts near the end of this method).
# cloned_input_tokens = frozen_model_input.input_tokens.clone()
# cloned_sampled_token_ids = sampled_token_ids.clone()
# cloned_input_positions = frozen_model_input.input_positions.clone()
# cloned_seq_lens = attn_metadata.seq_lens_tensor.clone()
# cloned_slot_mappings = attn_metadata.slot_mapping.clone()
# cloned_block_tables = attn_metadata.block_tables.clone()
# vllm._C.ops.advance_step_ipex(
#     num_seqs, num_queries, self.block_size,
#     cloned_input_tokens, cloned_sampled_token_ids, cloned_input_positions,
#     cloned_seq_lens, cloned_slot_mappings, cloned_block_tables)

# ------- Original Python implementation (mirrors ops.advance_step()) -------
next_seq_len = attn_metadata.seq_lens_tensor + 1
next_input_pos = next_seq_len - 1
attn_metadata.seq_lens_tensor = next_seq_len
@@ -486,7 +502,6 @@
attn_metadata.slot_mapping = slot_num.to(dtype=torch.long)

tmp_input_tokens = frozen_model_input.input_tokens
sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids
if sampled_token_ids.dim() > 1 and sampled_token_ids.size(-1) == 1:
sampled_token_ids = sampled_token_ids.squeeze(-1)
tmp_input_tokens[:num_queries] = sampled_token_ids[:num_queries]
@@ -498,13 +513,20 @@
input_positions=tmp_input_positions,
)

# Propagate the updated seq_lens back into frozen_model_input
if frozen_model_input.seq_lens is not None:
tmp_seq_lens = frozen_model_input.seq_lens
tmp_seq_lens[:num_queries] = attn_metadata.seq_lens[:num_queries]
frozen_model_input = dataclasses.replace(
frozen_model_input,
seq_lens=tmp_seq_lens,
)
# Bring-up checks: with the clone-based kernel call above enabled, the fused
# results must match the Python implementation.
# assert torch.equal(frozen_model_input.input_tokens, cloned_input_tokens)
# assert torch.equal(frozen_model_input.input_positions, cloned_input_positions)
# assert torch.equal(attn_metadata.slot_mapping, cloned_slot_mappings)
# assert torch.equal(attn_metadata.seq_lens_tensor, cloned_seq_lens)
# print("All checks passed")

return model_input
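
A possible bring-up test for the fused kernel is sketched below, comparing it against the pure-Python reference sketched after the kernel in attention_xpu.cpp above. Everything here is an assumption layered on the PR (test name, module path, advance_step_reference), not existing test code.

# Hypothetical bring-up test: compare advance_step_ipex against the
# pure-Python reference (advance_step_reference) on random inputs.
import torch
import vllm._C.ops as ops  # module path assumption

def test_advance_step_ipex(num_seqs=8, block_size=16, max_blocks=4):
    device = "xpu"
    seq_lens_cpu = torch.randint(1, block_size * max_blocks - 1, (num_seqs,),
                                 dtype=torch.int32)
    block_tables_cpu = torch.arange(num_seqs * max_blocks,
                                    dtype=torch.int32).reshape(num_seqs,
                                                               max_blocks)
    sampled_cpu = torch.randint(0, 32000, (num_seqs,), dtype=torch.long)

    # Device-side tensors for the kernel (updated in place).
    tokens = torch.zeros(num_seqs, dtype=torch.long, device=device)
    positions = torch.zeros(num_seqs, dtype=torch.long, device=device)
    slots = torch.zeros(num_seqs, dtype=torch.long, device=device)
    seq_lens = seq_lens_cpu.to(device)
    ops.advance_step_ipex(num_seqs, num_seqs, block_size,
                          tokens, sampled_cpu.to(device), positions,
                          seq_lens, slots, block_tables_cpu.to(device))
    torch.xpu.synchronize()

    # Host-side reference on plain Python lists.
    ref_tokens = [0] * num_seqs
    ref_positions = [0] * num_seqs
    ref_slots = [0] * num_seqs
    ref_seq_lens = seq_lens_cpu.tolist()
    advance_step_reference(num_seqs, block_size, ref_tokens,
                           sampled_cpu.tolist(), ref_positions, ref_seq_lens,
                           ref_slots, block_tables_cpu.tolist())

    assert tokens.cpu().tolist() == ref_tokens
    assert positions.cpu().tolist() == ref_positions
    assert seq_lens.cpu().tolist() == ref_seq_lens
    assert slots.cpu().tolist() == ref_slots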
