Skip to content

Commit d8c7414

Browse files
e-ddykim and yeonbok authored
[GPU] Calculate rope in f32 for gpt-oss model in PA mode (#32868)
### Details:
- Backport of #32844

### Tickets:
- 176323

Co-authored-by: Taylor Yeonbok Lee <[email protected]>
1 parent 163831a commit d8c7414

File tree

1 file changed

+49
-34
lines changed

1 file changed

+49
-34
lines changed

src/plugins/intel_gpu/src/plugin/transformations/increase_position_ids_precision.cpp

Lines changed: 49 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -172,83 +172,98 @@ IncreasePositionIdsPrecisionForLtxVideo::IncreasePositionIdsPrecisionForLtxVideo
172172
this->register_matcher(m, callback);
173173
}
174174

175+
// TODO: Implement a fused RoPE kernel for this pattern
175176
IncreasePositionIdsPrecisionForGPTOSS::IncreasePositionIdsPrecisionForGPTOSS() {
176177
using namespace ov::pass::pattern;
177178
using ov::pass::pattern::op::Or;
178179

179180
auto broadcast_freq = wrap_type<ov::op::v3::Broadcast>({any_input(), any_input()});
181+
auto convert_broadcast_freq = wrap_type<ov::op::v0::Convert>({broadcast_freq});
180182

181183
auto convert_pos_id_to_i32 = wrap_type<ov::op::v0::Convert>({any_input()});
182-
auto unsqueeze_pos_id = wrap_type<ov::op::v0::Unsqueeze>({convert_pos_id_to_i32, any_input()});
184+
auto unsqueeze_pos_id_1 = wrap_type<ov::op::v0::Unsqueeze>({convert_pos_id_to_i32, any_input()});
185+
auto unsqueeze_pos_id_2 = wrap_type<ov::op::v1::Reshape>({convert_pos_id_to_i32, any_input()});
186+
auto unsqueeze_pos_id = std::make_shared<Or>(OutputVector{unsqueeze_pos_id_1, unsqueeze_pos_id_2});
183187
auto convert_pos_id_to_f16 = wrap_type<ov::op::v0::Convert>({unsqueeze_pos_id});
184188

185-
auto convert_broadcast_freq = wrap_type<ov::op::v0::Convert>({broadcast_freq});
186-
187189
auto broadcast_freq_ = std::make_shared<Or>(OutputVector{broadcast_freq, convert_broadcast_freq});
188190

189191
auto matmul_freq_pos_id = wrap_type<ov::op::v0::MatMul>({broadcast_freq_, convert_pos_id_to_f16});
190-
auto transpose = wrap_type<ov::op::v1::Transpose>({matmul_freq_pos_id, any_input()});
192+
auto reshape_matmul = wrap_type<ov::op::v1::Reshape>({matmul_freq_pos_id, any_input()});
193+
auto transpose_matmul = wrap_type<ov::op::v1::Transpose>({matmul_freq_pos_id, any_input()});
194+
auto transpose_or_reshape = std::make_shared<Or>(OutputVector{transpose_matmul, reshape_matmul});
191195

192-
auto sin = wrap_type<ov::op::v0::Sin>({transpose});
193-
auto sin_convert = wrap_type<ov::op::v0::Convert>({sin});
194-
auto sin_ = std::make_shared<Or>(OutputVector{sin, sin_convert});
196+
auto sin_ = wrap_type<ov::op::v0::Sin>({transpose_or_reshape});
197+
auto sin_convert = wrap_type<ov::op::v0::Convert>({sin_});
198+
auto sin = std::make_shared<Or>(OutputVector{sin_, sin_convert});
195199

196-
auto cos = wrap_type<ov::op::v0::Cos>({transpose});
200+
auto cos = wrap_type<ov::op::v0::Cos>({transpose_or_reshape});
197201
auto cos_convert = wrap_type<ov::op::v0::Convert>({cos});
198202
auto cos_ = std::make_shared<Or>(OutputVector{cos, cos_convert});
199203

200204
auto scale_const_sin = wrap_type<ov::op::v0::Constant>();
201205
auto scale_const_sin_convert = wrap_type<ov::op::v0::Convert>({scale_const_sin});
202206
auto scale_const_sin_ = std::make_shared<Or>(OutputVector{scale_const_sin, scale_const_sin_convert});
203-
auto mul_sin_scale = wrap_type<ov::op::v1::Multiply>({sin_, scale_const_sin_});
207+
auto mul_sin_scale = wrap_type<ov::op::v1::Multiply>({sin, scale_const_sin_});
204208

205209
auto scale_const_cos = wrap_type<ov::op::v0::Constant>();
206210
auto scale_const_cos_convert = wrap_type<ov::op::v0::Convert>({scale_const_cos});
207211
auto scale_const_cos_ = std::make_shared<Or>(OutputVector{scale_const_cos, scale_const_cos_convert});
208212
auto mul_cos_scale = wrap_type<ov::op::v1::Multiply>({cos_, scale_const_cos_});
209213

210-
auto unsqueeze_mul_sin_scale = wrap_type<ov::op::v0::Unsqueeze>({mul_sin_scale, any_input()});
211-
auto mul_q_sin = wrap_type<ov::op::v1::Multiply>({any_input()/* q_second_half*/, unsqueeze_mul_sin_scale});
214+
auto unsqueeze_mul_sin_scale_ = wrap_type<ov::op::v0::Unsqueeze>({mul_sin_scale, any_input()});
215+
auto reshape_mul_sin_scale_ = wrap_type<ov::op::v1::Reshape>({mul_sin_scale, any_input()});
216+
auto reshape_mul_sin_scale = std::make_shared<Or>(OutputVector{unsqueeze_mul_sin_scale_, reshape_mul_sin_scale_});
217+
218+
auto mul_qk_sin = wrap_type<ov::op::v1::Multiply>({any_input()/*second_half*/, reshape_mul_sin_scale});
212219

213-
auto unsqueeze_mul_cos_scale = wrap_type<ov::op::v0::Unsqueeze>({mul_cos_scale, any_input()});
214-
auto mul_q_cos = wrap_type<ov::op::v1::Multiply>({any_input()/* q_second_half*/, unsqueeze_mul_cos_scale});
220+
auto unsqueeze_mul_cos_scale_ = wrap_type<ov::op::v0::Unsqueeze>({mul_cos_scale, any_input()});
221+
auto reshape_mul_cos_scale_ = wrap_type<ov::op::v1::Reshape>({mul_cos_scale, any_input()});
222+
auto reshape_mul_cos_scale = std::make_shared<Or>(OutputVector{unsqueeze_mul_cos_scale_, reshape_mul_cos_scale_});
223+
auto mul_qk_cos = wrap_type<ov::op::v1::Multiply>({any_input()/* first_half*/, reshape_mul_cos_scale});
215224

216-
auto q_half_mul1 = wrap_type<ov::op::v1::Multiply>({any_input(), any_input()});
217-
auto q_half_mul2 = wrap_type<ov::op::v1::Multiply>({q_half_mul1, any_input()});
218-
auto q_half_first = wrap_type<ov::op::v1::Add>({mul_q_cos, q_half_mul2});
225+
auto qk_half_mul1 = wrap_type<ov::op::v1::Multiply>({any_input(), any_input()});
226+
auto qk_half_mul2_1 = wrap_type<ov::op::v1::Multiply>({qk_half_mul1, any_input()});
227+
auto qk_half_mul2_2 = wrap_type<ov::op::v1::Multiply>({any_input(), qk_half_mul1});
228+
auto qk_half_mul2 = std::make_shared<Or>(OutputVector{qk_half_mul2_1, qk_half_mul2_2});
229+
auto qk_half_first_1 = wrap_type<ov::op::v1::Add>({mul_qk_cos, qk_half_mul2});
230+
auto qk_half_first_2 = wrap_type<ov::op::v1::Add>({qk_half_mul2, mul_qk_cos});
231+
auto qk_half_first = std::make_shared<Or>(OutputVector{qk_half_first_1, qk_half_first_2});
219232

220-
auto q_half_mul4 = wrap_type<ov::op::v1::Multiply>({any_input(), any_input()});
221-
auto q_half_second = wrap_type<ov::op::v1::Add>({mul_q_sin, q_half_mul4});
233+
auto qk_half_mul4 = wrap_type<ov::op::v1::Multiply>({any_input(), any_input()});
234+
auto qk_half_second_1 = wrap_type<ov::op::v1::Add>({mul_qk_sin, qk_half_mul4});
235+
auto qk_half_second_2 = wrap_type<ov::op::v1::Add>({qk_half_mul4, mul_qk_sin});
236+
auto qk_half_second = std::make_shared<Or>(OutputVector{qk_half_second_1, qk_half_second_2});
222237

223-
auto concat_q_1 = wrap_type<ov::op::v0::Concat>({q_half_second, q_half_first});
224-
auto concat_q_2 = wrap_type<ov::op::v0::Concat>({q_half_first, q_half_second});
225-
auto concat_q = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{concat_q_1, concat_q_2});
238+
auto concat_qk_1 = wrap_type<ov::op::v0::Concat>({qk_half_second, qk_half_first});
239+
auto concat_qk_2 = wrap_type<ov::op::v0::Concat>({qk_half_first, qk_half_second});
240+
auto concat_qk = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{concat_qk_1, concat_qk_2});
226241

227242
ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
228243
const auto& pattern_map = m.get_pattern_value_map();
229244
bool matched = false;
230-
if (pattern_map.count(concat_q_1) > 0) {
245+
if (pattern_map.count(concat_qk_1) > 0) {
231246
matched = true;
232-
} else if (pattern_map.count(concat_q_2) > 0) {
247+
} else if (pattern_map.count(concat_qk_2) > 0) {
233248
matched = true;
234249
}
235-
if (!matched || transformation_callback(concat_q))
250+
if (!matched || transformation_callback(concat_qk))
236251
return false;
237252

238253
std::shared_ptr<ov::op::v0::Concat> output_concat_node;
239-
if (pattern_map.count(concat_q_1) > 0) {
240-
output_concat_node = ov::as_type_ptr<ov::op::v0::Concat>(pattern_map.at(concat_q_1).get_node_shared_ptr());
241-
} else if (pattern_map.count(concat_q_2) > 0) {
242-
output_concat_node = ov::as_type_ptr<ov::op::v0::Concat>(pattern_map.at(concat_q_2).get_node_shared_ptr());
254+
if (pattern_map.count(concat_qk_1) > 0) {
255+
output_concat_node = ov::as_type_ptr<ov::op::v0::Concat>(pattern_map.at(concat_qk_1).get_node_shared_ptr());
256+
} else if (pattern_map.count(concat_qk_2) > 0) {
257+
output_concat_node = ov::as_type_ptr<ov::op::v0::Concat>(pattern_map.at(concat_qk_2).get_node_shared_ptr());
243258
}
244259
auto matmul_node = ov::as_type_ptr<ov::op::v0::MatMul>(pattern_map.at(matmul_freq_pos_id).get_node_shared_ptr());
245260
auto mul_node1 = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at(mul_sin_scale).get_node_shared_ptr());
246261
auto mul_node2 = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at(mul_cos_scale).get_node_shared_ptr());
247-
auto mul_node3 = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at(mul_q_sin).get_node_shared_ptr());
248-
auto mul_node4 = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at(mul_q_cos).get_node_shared_ptr());
249-
auto mul_node5 = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at(q_half_mul1).get_node_shared_ptr());
250-
auto mul_node6 = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at(q_half_mul2).get_node_shared_ptr());
251-
auto mul_node8 = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at(q_half_mul4).get_node_shared_ptr());
262+
auto mul_node3 = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at(mul_qk_sin).get_node_shared_ptr());
263+
auto mul_node4 = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at(mul_qk_cos).get_node_shared_ptr());
264+
auto mul_node5 = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at(qk_half_mul1).get_node_shared_ptr());
265+
auto mul_node6 = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at(qk_half_mul2).get_node_shared_ptr());
266+
auto mul_node8 = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at(qk_half_mul4).get_node_shared_ptr());
252267

253268
const auto desired_et = ov::element::f32;
254269
const auto original_et = output_concat_node->get_output_element_type(0);
@@ -268,7 +283,7 @@ IncreasePositionIdsPrecisionForGPTOSS::IncreasePositionIdsPrecisionForGPTOSS() {
268283
return true;
269284
};
270285

271-
auto m = std::make_shared<ov::pass::pattern::Matcher>(concat_q, "IncreasePositionIdsPrecisionForGPTOSS");
286+
auto m = std::make_shared<ov::pass::pattern::Matcher>(concat_qk, "IncreasePositionIdsPrecisionForGPTOSS");
272287
this->register_matcher(m, callback);
273288
}
274289

0 commit comments

Comments
 (0)