Skip to content

Commit 307db82

Browse files
sgbihuwine99t-jankowski
authored
Add Group Query Attention support with OV base OPs (#28163)
### Details: - Try to enable LLM based on onnxruntime. (Phi3, Llama3 is working on CPU, Phi3 can work with iGPU) ### Test scripts ``` import onnxruntime as rt import os import numpy as np import time import onnxruntime.tools.add_openvino_win_libs as utils utils.add_openvino_libs_to_path() from transformers import PreTrainedTokenizerFast test_lama3 = False test_phi3 = True if test_phi3: modelPath = os.path.join('D:\\', 'models', 'llm', 'Phi-3-mini-4k-instruct-onnx', 'model.onnx') tokenizerPath = os.path.join('D:\\', 'models', 'llm', 'Phi-3-mini-4k-instruct-onnx', 'tokenizer.json') if test_lama3: modelPath = os.path.join('D:\\', 'models', 'llm', 'llama3.1-8B-instruct-onnx', 'model.onnx') so = rt.SessionOptions() # so.log_severity_level = 3 # sess = rt.InferenceSession(modelPath, so, providers=['CPUExecutionProvider']) sess = rt.InferenceSession(modelPath, so, providers=['OpenVINOExecutionProvider'], provider_options=[{'device_type' : "CPU", 'cache_dir': "cache"}]) # sess = rt.InferenceSession(modelPath, so, providers=['OpenVINOExecutionProvider'], provider_options=[{'device_type' : "CPU"}]) # sess = rt.InferenceSession(modelPath, so, providers=['OpenVINOExecutionProvider'], provider_options=[{'device_type' : "NPU"}]) tokenizer = PreTrainedTokenizerFast(tokenizer_file=tokenizerPath) # print(sess.get_device()) # for name in sess.get_inputs(): # print(f"Name: {name.name}, Shape: {name.shape}, Type: {name.type}") outputs = sess.get_outputs() output_names = list(map(lambda output: output.name, outputs)) # Assuming the model has 32 layers and each layer has a key and value state # Phi3 def get_phi3_param(): num_layers = 32 batch_size = 1 num_heads = 32 sequence_length = 2048 hidden_size = 96 return num_layers, batch_size, num_heads, sequence_length, hidden_size # lama def get_llama3_param(): num_layers = 32 batch_size = 1 num_heads = 8 sequence_length = 2048 hidden_size = 128 return num_layers, batch_size, num_heads, sequence_length, hidden_size if test_phi3: num_layers, 
batch_size, num_heads, sequence_length, hidden_size = get_phi3_param() if test_lama3: num_layers, batch_size, num_heads, sequence_length, hidden_size = get_llama3_param() # Initialize past_key_values with zeros cpu_array = np.zeros((batch_size, num_heads, sequence_length, hidden_size), dtype=np.float32) # print("Output names: ", outputs[0].type.data) def create_present_state_binding(binding, outputs): outputMap={} for output in outputs: shapes = [] for item in output.shape: if isinstance(item, str): if 'batch_size' in item: shapes.append(batch_size) elif 'sequence_length' in item: if output.name == 'logits': shapes.append(len(inputToken)) else: shapes.append(sequence_length) elif 'hidden_size' in item: shapes.append(hidden_size) elif 'num_heads' in item: shapes.append(num_heads) else: raise ValueError(f"Unknown dimension: {item}") else: shapes.append(item) present_state = rt.OrtValue.ortvalue_from_shape_and_type(shapes, np.float32) binding.bind_ortvalue_output(output.name, present_state) outputMap[output.name] = present_state return outputMap def rebind_inputs(lastOutput, binding): for index in range(num_layers): binding.bind_ortvalue_input(f'past_key_values.{index}.key', lastOutput[f'present.{index}.key']) binding.bind_ortvalue_input(f'past_key_values.{index}.value', lastOutput[f'present.{index}.value']) return binding def init_input_with_binding(binding): for index in range(num_layers): key_state = rt.OrtValue.ortvalue_from_numpy(cpu_array) value_state = rt.OrtValue.ortvalue_from_numpy(cpu_array) binding.bind_ortvalue_input(f'past_key_values.{index}.key', key_state) binding.bind_ortvalue_input(f'past_key_values.{index}.value', value_state) return binding def reinit_input_bindings(bindings, lastOutput): newOutput = create_present_state_binding(bindings, lastOutput) binding = rebind_inputs(lastOutput, bindings) return binding, newOutput def create_numpy_inputs(inputToken): tokenLen = len(inputToken) npinput_ids = np.array([inputToken], dtype=np.int64) 
npattention_mask = np.array([[1] * (tokenLen)], dtype=np.int64) return npinput_ids, npattention_mask def init_ortinput(inputToken): flattened_past_key_values = {} for index in range(num_layers): key_state = rt.OrtValue.ortvalue_from_numpy(cpu_array) value_state = rt.OrtValue.ortvalue_from_numpy(cpu_array) flattened_past_key_values[f'past_key_values.{index}.key'] = key_state flattened_past_key_values[f'past_key_values.{index}.value'] = value_state ids, mask = create_numpy_inputs(inputToken) flattened_past_key_values['input_ids'] = rt.OrtValue.ortvalue_from_numpy(ids) flattened_past_key_values['attention_mask'] = rt.OrtValue.ortvalue_from_numpy(mask) return flattened_past_key_values def init_npinput(inputToken): flattened_past_key_values = {} for index in range(num_layers): key_state = np.zeros((batch_size, num_heads, sequence_length, hidden_size), dtype=np.float32) value_state = np.zeros((batch_size, num_heads, sequence_length, hidden_size), dtype=np.float32) flattened_past_key_values[f'past_key_values.{index}.key'] = key_state flattened_past_key_values[f'past_key_values.{index}.value'] = value_state flattened_past_key_values['input_ids'], flattened_past_key_values['attention_mask'] = create_numpy_inputs(inputToken) return flattened_past_key_values def init_bindinginput(inputToken): binding = sess.io_binding() binding = init_input_with_binding(binding) ids, mask = create_numpy_inputs(inputToken) binding.bind_ortvalue_input(f'attention_mask', rt.OrtValue.ortvalue_from_numpy(mask)) binding.bind_ortvalue_input(f'input_ids', rt.OrtValue.ortvalue_from_numpy(ids)) return binding # Question # The Sun is yellow because # Phi3 if test_phi3: # 450 8991 5692 # inputToken = [32010, 29871, 13] inputToken = [32010, 29871, 13, 1576, 8991, 338, 13328, 1363, 29871, 32007, 13, 32001] # inputToken = [32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010, 32010] # lama3 if test_lama3: # 315 1202 7479 inputToken = [128000, 27, 91, 882, 91, 397, 791, 8219, 374, 
14071, 1606, 83739, 408, 91, 397, 27, 91, 78191, 91, 29] # inputToken = [315] history_tokens = inputToken flattened_past_key_values = init_npinput(inputToken) # flattened_past_key_values = init_ortinput(inputToken) # binding = init_bindinginput(inputToken) # lastoutput = create_present_state_binding(binding, outputs) lastTokenLen = len(inputToken) # roption = rt.RunOptions() # roption.add_run_config_entry("gpu_graph_id", "-1") before = time.time() results = sess.run(output_names, flattened_past_key_values) # results = sess.run_with_iobinding(binding) # results = sess.run_with_ort_values(output_names, flattened_past_key_values) after = time.time() print("Time cost in ms: ", (after - before) * 1000) # print(np.argmax(results[0].numpy(), axis=-1)[-1]) print(np.argmax(results[0], axis=-1)[-1]) # print(results[0]) # print(output_names[1]) # print(results[1][0][0][0]) # print(results[1][0][0][1]) # print(results[1][0][0][2]) # # print(results[1][0][0][14]) # # print(results[1]) # print(output_names[2]) # # print(results[2]) # print(results[2][0][0][0]) # print(results[2][0][0][1]) # print(results[2][0][0][2]) # print(results[2][0][0][14]) # inputToken.append(450) # rebind_inputs(lastOutput, binding) def update_kvcache(inputsMap, results): for index in range(len(output_names)): if not output_names[index].startswith('present'): continue # print(f'{output_names[index]}: {results[index].shape}') outputname = output_names[index] inputname = outputname.replace('present', 'past_key_values') inputsMap[inputname] = results[index] return inputsMap # lastOutput = create_present_state_binding(binding, sess.get_outputs()) # flattened_past_key_values = update_kvcache(flattened_past_key_values, results) for index in range(len(output_names)): if not output_names[index].startswith('present'): continue # print(f'{output_names[index]}: {results[index].shape}') outputname = output_names[index] inputname = outputname.replace('present', 'past_key_values') flattened_past_key_values[inputname] 
= results[index] if test_phi3: inputToken = [450] if test_lama3: inputToken = [315] history_tokens += inputToken npinput_ids = np.array([inputToken], dtype=np.int64) npattention_mask = np.array([[1] * (lastTokenLen+1)], dtype=np.int64) print(f"lastTokenLen:{lastTokenLen}") # attention_mask = rt.OrtValue.ortvalue_from_numpy(npattention_mask) # input_ids = rt.OrtValue.ortvalue_from_numpy(npinput_ids) # binding.bind_ortvalue_input(f'attention_mask', attention_mask) # binding.bind_ortvalue_input(f'input_ids', input_ids) # flattened_past_key_values[f'attention_mask'].update_inplace(npattention_mask) # flattened_past_key_values[f'input_ids'].update_inplace(npinput_ids) # flattened_past_key_values[f'attention_mask'] = attention_mask # flattened_past_key_values[f'input_ids'] = input_ids flattened_past_key_values[f'attention_mask'] = npattention_mask flattened_past_key_values[f'input_ids'] = npinput_ids # print(flattened_past_key_values) before = time.time() results = sess.run(output_names, flattened_past_key_values) # results = sess.run_with_iobinding(binding) # results = sess.run_with_ort_values(output_names, flattened_past_key_values) after = time.time() print("Time cost in ms: ", (after - before) * 1000) # Results: [np.int32(450), np.int32(8991), np.int32(5692), np.int32(13328), np.int32(304), np.int32(502), np.int32(19434), np.int32(2861), np.int32(304), np.int32(9596), np.int32(280), np.int32(1141), np.int32(14801), np.int32(292), np.int32(29889), np.int32(1932), np.int32(6575), np.int32(4366), np.int32(14517), np.int32(1549), np.int32(278), np.int32(11563), np.int32(29915), np.int32(29879), np.int32(25005), np.int32(29892), np.int32(278), np.int32(20511), np.int32(7254), np.int32(281), np.int32(6447), np.int32(1477), np.int32(29879), np.int32(526), np.int32(29574), np.int32(297), np.int32(599), np.int32(18112), np.int32(491), np.int32(278), np.int32(330), np.int32(2129), np.int32(322), np.int32(17105), np.int32(297), np.int32(278), np.int32(4799), np.int32(29889), 
np.int32(910), np.int32(14801), np.int32(292), np.int32(9946), np.int32(278), np.int32(14744), np.int32(304), np.int32(1106), np.int32(7254), np.int32(29889), np.int32(2398), np.int32(29892), np.int32(278), np.int32(5520), np.int32(2654), np.int32(322), np.int32(13328), np.int32(281), np.int32(6447), np.int32(1477), np.int32(29879), np.int32(1209), np.int32(1549), np.int32(278), np.int32(25005), np.int32(901), np.int32(5948), np.int32(322), np.int32(526), np.int32(3109), np.int32(29574), np.int32(29889), np.int32(1932), np.int32(591), np.int32(1106), np.int32(472), np.int32(278), np.int32(8991), np.int32(29892), np.int32(591), np.int32(1074), np.int32(372), np.int32(408), np.int32(263), np.int32(13328), np.int32(470), np.int32(24841), np.int32(8086), np.int32(1363), np.int32(278), np.int32(7254), np.int32(3578), np.int32(338), np.int32(29574), np.int32(714), np.int32(310), np.int32(1749), np.int32(1196), np.int32(310), np.int32(11126), np.int32(29892), np.int32(322), np.int32(278), np.int32(9886), np.int32(3578), np.int32(393), np.int32(22170), np.int32(1749), np.int32(5076), np.int32(338), np.int32(758), np.int32(24130), np.int32(10835), np.int32(13328), np.int32(322), np.int32(2654), np.int32(29889), np.int32(32000)] # index = 0 # for result in results: # print(f'{output_names[index]}: {result.shape}, {result.dtype}') # index += 1 print(np.argmax(results[0], axis=-1)[-1]) # print(np.argmax(results[0].numpy(), axis=-1)[-1]) # golden results # Time cost in ms: 1255.2332878112793 # [30751 13 13 1494 1731 263 29889 372 13 24380 13 450] # lastTokenLen:12 # Time cost in ms: 1006.781816482544 # [8991] last_generated_token = np.argmax(results[0], axis=-1)[-1][-1] history_tokens.append(last_generated_token) NUM_INFERENCE = 15 for i in range(NUM_INFERENCE): # update kvcahe for index in range(len(output_names)): if not output_names[index].startswith('present'): continue # print(f'{output_names[index]}: {results[index].shape}') outputname = output_names[index] inputname = 
outputname.replace('present', 'past_key_values') flattened_past_key_values[inputname] = results[index] # update input token flattened_past_key_values[f'input_ids'] = np.array([[last_generated_token]], dtype=np.int64) flattened_past_key_values[f'attention_mask'] = np.array([[1] * len(history_tokens)], dtype=np.int64) before = time.time() results = sess.run(output_names, flattened_past_key_values) after = time.time() print("Time cost in ms: ", (after - before) * 1000) last_generated_token = np.argmax(results[0], axis=-1)[-1][-1] history_tokens.append(last_generated_token) print(tokenizer.decode(history_tokens)) ``` ### Tickets: - related to 155287, 157123 --------- Co-authored-by: Yu, Zijun <[email protected]> Co-authored-by: Tomasz Jankowski <[email protected]>
1 parent 622a5dd commit 307db82

File tree

9 files changed

+1305
-0
lines changed

9 files changed

+1305
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/group_query_attention.hpp"
#include "openvino/op/shape_of.hpp"
#include "openvino/pass/matcher_pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API GroupQueryAttentionDecomposition;

}  // namespace pass
}  // namespace ov

/// \brief Decomposes the experimental ov::op::internal::GroupQueryAttention operation
///        into a subgraph of standard OpenVINO operations (ScaledDotProductAttention,
///        Slice, Concat, Reshape, ...), including optional rotary position embedding.
class ov::pass::GroupQueryAttentionDecomposition : public ov::pass::MatcherPass {
public:
    OPENVINO_MATCHER_PASS_RTTI("GroupQueryAttentionDecomposition");
    GroupQueryAttentionDecomposition();

private:
    // Builds the replacement subgraph and returns {attention_output, present_key, present_value}.
    ov::OutputVector decompose(std::shared_ptr<ov::op::internal::GroupQueryAttention> node);
    // Gathers the dimensions listed in `dims` from an existing ShapeOf result.
    std::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<op::v3::ShapeOf>& shape,
                                             const std::vector<int>& dims);
    // Convenience overload: creates ShapeOf(node) and gathers the requested dimensions.
    std::shared_ptr<ov::Node> get_dimensions(const std::shared_ptr<ov::Node>& node, const std::vector<int>& dims);
    // Splits `value` into `num_splits` equal parts along `axis`.
    ov::OutputVector make_split(const ov::Output<ov::Node>& value, int64_t num_splits, int64_t axis);
    // Applies rotary position embedding (interleaved or half-split layout) to `input`.
    std::shared_ptr<ov::Node> rotaryEmbedding(ov::Output<ov::Node> input,
                                              ov::Output<ov::Node> past_seqlen,
                                              std::shared_ptr<ov::Node> seqlen_k,
                                              std::shared_ptr<ov::Node> cos_cache,
                                              std::shared_ptr<ov::Node> sin_cache,
                                              std::shared_ptr<ov::Node> dim_head_size,
                                              bool interleaved);
};

src/common/transformations/src/transformations/common_optimizations/common_optimizations.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@
108108
#include "transformations/op_conversions/eye_decomposition.hpp"
109109
#include "transformations/op_conversions/gelu7_downgrade.hpp"
110110
#include "transformations/op_conversions/group_normalization_decomposition.hpp"
111+
#include "transformations/op_conversions/group_query_attention_decomposition.hpp"
111112
#include "transformations/op_conversions/hsigmoid_decomposition.hpp"
112113
#include "transformations/op_conversions/hswish_decomposition.hpp"
113114
#include "transformations/op_conversions/log_softmax_decomposition.hpp"
@@ -156,6 +157,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model
156157
REGISTER_DISABLED_PASS(manager, ConvertInterpolate1ToInterpolate4)
157158

158159
auto decomp = manager.register_pass<GraphRewrite>();
160+
ADD_MATCHER(decomp, GroupQueryAttentionDecomposition)
159161
ADD_MATCHER(decomp, ScaledDotProductAttentionDecomposition)
160162
ADD_MATCHER(decomp, Gelu7Downgrade)
161163
ADD_MATCHER(decomp, BidirectionalSequenceDecomposition)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
// Copyright (C) 2018-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "transformations/op_conversions/group_query_attention_decomposition.hpp"
6+
7+
#include <memory>
8+
9+
#include "itt.hpp"
10+
#include "openvino/core/rt_info.hpp"
11+
#include "openvino/op/add.hpp"
12+
#include "openvino/op/concat.hpp"
13+
#include "openvino/op/constant.hpp"
14+
#include "openvino/op/convert.hpp"
15+
#include "openvino/op/gather.hpp"
16+
#include "openvino/op/greater.hpp"
17+
#include "openvino/op/multiply.hpp"
18+
#include "openvino/op/range.hpp"
19+
#include "openvino/op/reshape.hpp"
20+
#include "openvino/op/scaled_dot_product_attention.hpp"
21+
#include "openvino/op/select.hpp"
22+
#include "openvino/op/shape_of.hpp"
23+
#include "openvino/op/slice.hpp"
24+
#include "openvino/op/split.hpp"
25+
#include "openvino/op/subtract.hpp"
26+
#include "openvino/op/transpose.hpp"
27+
#include "openvino/op/unsqueeze.hpp"
28+
#include "openvino/pass/pattern/op/wrap_type.hpp"
29+
30+
// Matcher pass constructor: matches any internal GroupQueryAttention node and
// replaces it with the decomposed subgraph produced by decompose().
ov::pass::GroupQueryAttentionDecomposition::GroupQueryAttentionDecomposition() {
    // Fixed typo: was MATCHER_SCOPE(GroupQeuryAttentionDecomposition).
    MATCHER_SCOPE(GroupQueryAttentionDecomposition);
    auto pattern_node = ov::pass::pattern::wrap_type<ov::op::internal::GroupQueryAttention>();

    matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
        auto& pattern_to_output = m.get_pattern_value_map();
        auto node = ov::as_type_ptr<ov::op::internal::GroupQueryAttention>(
            pattern_to_output.at(pattern_node).get_node_shared_ptr());

        // Skip if the matched node is not a GQA node or the plugin opted out
        // via the transformation callback.
        if (node == nullptr || transformation_callback(node)) {
            return false;
        }

        auto new_output_node = decompose(node);
        ov::replace_node(node, new_output_node);
        return true;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(pattern_node, matcher_name);
    register_matcher(m, callback);
}
51+
52+
// Decomposes GroupQueryAttention into base OPs.
// Inputs (ONNX Runtime GQA layout): Q, K, V, past_key, past_value, seqlens_k,
// cos_cache, sin_cache. Returns {attention_output, present_key, present_value}.
ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose(
    std::shared_ptr<ov::op::internal::GroupQueryAttention> node) {
    using namespace ov::op;

    const auto num_heads = node->get_num_heads();
    const auto kv_num_heads = node->get_kv_num_heads();
    const auto scale = node->get_scale();
    const auto do_rotary = node->get_do_rotary();
    const auto rotary_interleaved = node->get_rotary_interleaved();
    // TODO: add softcap support

    auto Q = node->input_value(0);
    auto K = node->input_value(1);
    auto V = node->input_value(2);
    auto past_key = node->input_value(3);
    auto past_value = node->input_value(4);
    auto seqlens_k = node->input_value(5);
    auto cos_cache = node->input_value(6);
    auto sin_cache = node->input_value(7);

    // The length of all tokens (past + current) is `seqlens_k` + 1
    // current = Q.shape[2], past = `seqlens_k` + 1 - current

    const auto T = Q.get_element_type();
    const auto q_shape = register_new_node<v3::ShapeOf>(Q);
    const auto current_sequence_length = get_dimensions(q_shape, {2});
    const auto head_size_node = get_dimensions(q_shape, {3});

    auto zero = register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{1}, {0}));
    auto one = register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{1}, {1}));
    auto one_without_shape = register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{}, {1}));
    auto two = register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{1}, {2}));
    auto seqlens_elemi64 = register_new_node<v0::Convert>(seqlens_k, ov::element::i64);
    auto real_seqlens = register_new_node<v1::Add>(seqlens_elemi64, one);

    // Only consider batch is 1
    auto seqlens_1d = register_new_node<v1::Reshape>(real_seqlens, one, false);
    auto past_sequence_length = register_new_node<v1::Subtract>(seqlens_1d, current_sequence_length);

    // Optionally apply rotary position embedding to Q and K before caching.
    if (do_rotary) {
        Q = rotaryEmbedding(Q,
                            past_sequence_length,
                            seqlens_1d,
                            cos_cache.get_node_shared_ptr(),
                            sin_cache.get_node_shared_ptr(),
                            head_size_node,
                            rotary_interleaved);
        K = rotaryEmbedding(K,
                            past_sequence_length,
                            seqlens_1d,
                            cos_cache.get_node_shared_ptr(),
                            sin_cache.get_node_shared_ptr(),
                            head_size_node,
                            rotary_interleaved);
    }

    // Concatenate the valid part of the past cache with the current tokens
    // along the sequence axis (2).
    auto construct_kv_cache = [&](const ov::Output<ov::Node>& past, const ov::Output<ov::Node>& current) {
        auto past_datas = register_new_node<v8::Slice>(past, zero, past_sequence_length, one, two);
        auto curr_datas = register_new_node<v8::Slice>(current, zero, current_sequence_length, one, two);
        return register_new_node<v0::Concat>(ov::NodeVector{past_datas, curr_datas}, 2);
    };
    K = construct_kv_cache(past_key, K);
    V = construct_kv_cache(past_value, V);
    auto present_k = K;
    auto present_v = V;

    // Broadcast the KV heads so that K/V have the same number of heads as Q
    // (grouped-query attention: num_heads is a multiple of kv_num_heads).
    const size_t kv_num_heads_factor = num_heads / kv_num_heads;
    if (kv_num_heads_factor > 1) {
        const auto kv_shape = register_new_node<v3::ShapeOf>(K);
        const auto kv_shape_prev_2 = get_dimensions(kv_shape, {0, 1});
        const auto kv_shape_last_2 = get_dimensions(kv_shape, {2, 3});
        // Insert a singleton axis after the head axis, repeat along it, then
        // fold the repeats back into the head axis.
        auto new_kv_shape = register_new_node<v0::Concat>(ov::NodeVector{kv_shape_prev_2, one, kv_shape_last_2}, 0);
        K = register_new_node<v1::Reshape>(K, new_kv_shape, false);
        V = register_new_node<v1::Reshape>(V, new_kv_shape, false);
        K = register_new_node<v0::Concat>(ov::OutputVector(kv_num_heads_factor, K), 2);
        V = register_new_node<v0::Concat>(ov::OutputVector(kv_num_heads_factor, V), 2);
        const auto q_shape = register_new_node<v3::ShapeOf>(Q);
        const auto q_shape_prev_2 = get_dimensions(q_shape, {0, 1});
        auto extended_kv_shape = register_new_node<v0::Concat>(ov::NodeVector{q_shape_prev_2, kv_shape_last_2}, 0);
        K = register_new_node<v1::Reshape>(K, extended_kv_shape, false);
        V = register_new_node<v1::Reshape>(V, extended_kv_shape, false);
    }

    // need to apply low-triangle mask to attention score.
    // two steps, construct the total_sequence x total_sequence triangle, then slice the current length
    auto seqlens_1d_scalar = register_new_node<v1::Reshape>(seqlens_1d, one_without_shape, false);
    std::shared_ptr<ov::Node> mask_per_line_node =
        register_new_node<v4::Range>(register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{}, {0})),
                                     seqlens_1d_scalar,
                                     one_without_shape,
                                     ov::element::i64);
    auto hori_range = register_new_node<v0::Unsqueeze>(mask_per_line_node, zero);
    auto vert_range = register_new_node<v0::Unsqueeze>(mask_per_line_node, one);
    auto triu = register_new_node<v1::Greater>(hori_range, vert_range);
    auto typed_zero = register_new_node(v0::Constant::create(T, ov::Shape{}, {0}));
    // cf. make_attention_mask@src\plugins\intel_gpu\tests\common\subgraphs_builders.hpp
    // NOTE(review): only f32/f16 are handled; for any other element type
    // minus_inf stays nullptr and Select creation would fail — confirm GQA is
    // restricted to f32/f16 upstream.
    std::shared_ptr<ov::Node> minus_inf = nullptr;
    if (T == ov::element::f32)
        minus_inf = register_new_node(v0::Constant::create(T, ov::Shape{}, {-std::numeric_limits<float>::infinity()}));
    else if (T == ov::element::f16)
        minus_inf =
            register_new_node(v0::Constant::create(T, ov::Shape{}, {std::numeric_limits<ov::float16>::lowest()}));
    auto atten_mask = register_new_node<v1::Select>(triu, minus_inf, typed_zero);
    auto atten_mask_sliced = register_new_node<v8::Slice>(atten_mask, past_sequence_length, seqlens_1d, one, zero);

    // Scaled dot-product attention; pass the explicit scale only when provided
    // (0 means "use the operation's default scaling").
    std::shared_ptr<ov::Node> qga_output;
    if (scale != 0.0f) {
        auto scale_node = register_new_node(v0::Constant::create(T, ov::Shape{}, {scale}));
        qga_output = register_new_node<v13::ScaledDotProductAttention>(Q, K, V, atten_mask_sliced, scale_node, false);
    } else {
        qga_output = register_new_node<v13::ScaledDotProductAttention>(Q, K, V, atten_mask_sliced, false);
    }

    // transpose the result from (batch_size, num_heads, sequence_length, head_size)
    // to (batch_size, sequence_length, num_heads * head_size)
    auto perm = register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3}));
    auto qga_output_transposed = register_new_node<v1::Transpose>(qga_output, perm);
    auto dim_merge_shape = register_new_node(v0::Constant::create(ov::element::i32, ov::Shape{3}, {0, 0, -1}));
    auto output = register_new_node<v1::Reshape>(qga_output_transposed, dim_merge_shape, true)->output(0);

    return {output, present_k, present_v};
}
173+
174+
// make split functions is a copy-past from ONNX FE. TODO: move it to one place
175+
ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::make_split(const ov::Output<ov::Node>& value,
176+
int64_t num_splits,
177+
int64_t axis) {
178+
using namespace ov::op;
179+
const auto axis_node = register_new_node(v0::Constant::create(ov::element::i64, ov::Shape{}, {axis}));
180+
const auto split = register_new_node<v1::Split>(value, axis_node, num_splits);
181+
182+
return split->outputs();
183+
}
184+
185+
std::shared_ptr<ov::Node> ov::pass::GroupQueryAttentionDecomposition::get_dimensions(
186+
const std::shared_ptr<ov::op::v3::ShapeOf>& shape,
187+
const std::vector<int>& dims) {
188+
using namespace ov::op;
189+
const auto zero = v0::Constant::create(ov::element::i32, ov::Shape{}, {0});
190+
const auto dims_const = v0::Constant::create(ov::element::i32, ov::Shape{dims.size()}, dims);
191+
return register_new_node<v8::Gather>(shape, dims_const, zero);
192+
}
193+
194+
std::shared_ptr<ov::Node> ov::pass::GroupQueryAttentionDecomposition::get_dimensions(
195+
const std::shared_ptr<ov::Node>& node,
196+
const std::vector<int>& dims) {
197+
return get_dimensions(register_new_node<ov::op::v3::ShapeOf>(node), dims);
198+
}
199+
200+
// Applies rotary position embedding to `input` using cos/sin caches sliced to
// the range [past_seqlen, seqlen_k). `interleaved` selects the element layout:
// interleaved pairs (x0,x1,x2,x3 -> pairs (x0,x1),(x2,x3)) vs. half-split
// (first half paired with second half).
std::shared_ptr<ov::Node> ov::pass::GroupQueryAttentionDecomposition::rotaryEmbedding(
    ov::Output<ov::Node> input,
    ov::Output<ov::Node> past_seqlen,
    std::shared_ptr<ov::Node> seqlen_k,
    std::shared_ptr<ov::Node> cos_cache,
    std::shared_ptr<ov::Node> sin_cache,
    std::shared_ptr<ov::Node> dim_head_size,
    bool interleaved) {
    using namespace ov::op;
    auto zero = v0::Constant::create(ov::element::i64, ov::Shape{1}, {0});
    auto one = v0::Constant::create(ov::element::i64, ov::Shape{1}, {1});

    auto slice_cache_dim_shape = seqlen_k;

    // Slice the caches to the positions covered by the current tokens.
    auto cos = register_new_node<v8::Slice>(cos_cache, past_seqlen, slice_cache_dim_shape, one, zero);
    auto sin = register_new_node<v8::Slice>(sin_cache, past_seqlen, slice_cache_dim_shape, one, zero);

    if (interleaved) {
        auto two = v0::Constant::create(ov::element::i64, ov::Shape{1}, {2});

        // The cache's last dimension is half the head size (one entry per pair).
        auto cache_last_dim = get_dimensions(cos_cache, {-1});

        auto input_shape = register_new_node<v3::ShapeOf>(input);

        auto dim_bns = get_dimensions(input_shape, {0, 1, 2});
        std::shared_ptr<ov::Node> half_last_dim = cache_last_dim;

        // Removed unused locals from the original: `cache_shape` and `negtive_one`.
        // Reshape (..., head) -> (..., head/2, 2) so each rotation pair is adjacent.
        auto split_input_shape = register_new_node<v0::Concat>(ov::NodeVector{dim_bns, half_last_dim, two}, 0);
        auto reshaped_input = register_new_node<v1::Reshape>(input, split_input_shape, false);

        auto in_split = make_split(reshaped_input, 2, -1);
        split_input_shape = register_new_node<v0::Concat>(ov::NodeVector{dim_bns, half_last_dim}, 0);
        auto in_split_0 = register_new_node<v1::Reshape>(in_split[0], split_input_shape, false);
        auto in_split_1 = register_new_node<v1::Reshape>(in_split[1], split_input_shape, false);

        // Standard rotation: (x*cos - y*sin, x*sin + y*cos).
        auto res_0 = register_new_node<v1::Subtract>(register_new_node<v1::Multiply>(in_split_0, cos),
                                                     register_new_node<v1::Multiply>(in_split_1, sin));
        auto res_1 = register_new_node<v1::Add>(register_new_node<v1::Multiply>(in_split_0, sin),
                                                register_new_node<v1::Multiply>(in_split_1, cos));

        // Re-interleave the rotated halves and restore the original shape.
        split_input_shape = register_new_node<v0::Concat>(ov::NodeVector{dim_bns, half_last_dim, one}, 0);
        auto res_0_5d = register_new_node<v1::Reshape>(res_0, split_input_shape, false);
        auto res_1_5d = register_new_node<v1::Reshape>(res_1, split_input_shape, false);

        auto concat_ret = register_new_node<v0::Concat>(ov::NodeVector{res_0_5d, res_1_5d}, -1);
        return register_new_node<v1::Reshape>(concat_ret, input_shape, false);
    } else {
        // Half-split layout: rotate the first half of the head dim against the second.
        auto in_split = make_split(input, 2, -1);
        auto res_0 = register_new_node<v1::Subtract>(register_new_node<v1::Multiply>(in_split[0], cos),
                                                     register_new_node<v1::Multiply>(in_split[1], sin));
        auto res_1 = register_new_node<v1::Add>(register_new_node<v1::Multiply>(in_split[0], sin),
                                                register_new_node<v1::Multiply>(in_split[1], cos));

        return register_new_node<v0::Concat>(ov::NodeVector{res_0, res_1}, -1);
    }
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once

#include "openvino/op/op.hpp"

namespace ov::op::internal {

// This is an experimental operation that is implemented in the plugins.
// It mirrors the ONNX Runtime GroupQueryAttention contrib op: attention where
// the number of query heads is a multiple of the number of key/value heads.
class OPENVINO_API GroupQueryAttention : public Op {
public:
    OPENVINO_OP("GroupQueryAttention");

    GroupQueryAttention() = default;
    // args: Q, K, V, past_key, past_value, seqlens_k, cos_cache, sin_cache
    // (ordering assumed from the decomposition pass — confirm against the FE).
    GroupQueryAttention(const ov::OutputVector& args,
                        int64_t num_heads,
                        int64_t kv_num_heads,
                        float scale,
                        bool do_rotary,
                        bool rotary_interleaved);
    void validate_and_infer_types() override;
    bool visit_attributes(AttributeVisitor& visitor) override;
    std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

    // Number of attention (query) heads.
    int64_t get_num_heads() const {
        return m_num_heads;
    }
    // Number of key/value heads; num_heads must be divisible by this.
    int64_t get_kv_num_heads() const {
        return m_kv_num_heads;
    }
    // Attention scale; 0 means "use the default scaling".
    float get_scale() const {
        return m_scale;
    }
    // Whether rotary position embedding is applied to Q and K.
    bool get_do_rotary() const {
        return m_do_rotary;
    }
    // Rotary layout: interleaved pairs vs. half-split.
    bool get_rotary_interleaved() const {
        return m_rotary_interleaved;
    }

private:
    int64_t m_num_heads = 0;
    int64_t m_kv_num_heads = 0;
    float m_scale = 0;
    bool m_do_rotary = false;
    bool m_rotary_interleaved = false;
};

}  // namespace ov::op::internal

0 commit comments

Comments
 (0)