@@ -172,83 +172,98 @@ IncreasePositionIdsPrecisionForLtxVideo::IncreasePositionIdsPrecisionForLtxVideo
172172 this ->register_matcher (m, callback);
173173}
174174
175+ // TODO: implement a fused RoPE kernel for this pattern
175176IncreasePositionIdsPrecisionForGPTOSS::IncreasePositionIdsPrecisionForGPTOSS () {
176177 using namespace ov ::pass::pattern;
177178 using ov::pass::pattern::op::Or;
178179
179180 auto broadcast_freq = wrap_type<ov::op::v3::Broadcast>({any_input (), any_input ()});
181+ auto convert_broadcast_freq = wrap_type<ov::op::v0::Convert>({broadcast_freq});
180182
181183 auto convert_pos_id_to_i32 = wrap_type<ov::op::v0::Convert>({any_input ()});
182- auto unsqueeze_pos_id = wrap_type<ov::op::v0::Unsqueeze>({convert_pos_id_to_i32, any_input ()});
184+ auto unsqueeze_pos_id_1 = wrap_type<ov::op::v0::Unsqueeze>({convert_pos_id_to_i32, any_input ()});
185+ auto unsqueeze_pos_id_2 = wrap_type<ov::op::v1::Reshape>({convert_pos_id_to_i32, any_input ()});
186+ auto unsqueeze_pos_id = std::make_shared<Or>(OutputVector{unsqueeze_pos_id_1, unsqueeze_pos_id_2});
183187 auto convert_pos_id_to_f16 = wrap_type<ov::op::v0::Convert>({unsqueeze_pos_id});
184188
185- auto convert_broadcast_freq = wrap_type<ov::op::v0::Convert>({broadcast_freq});
186-
187189 auto broadcast_freq_ = std::make_shared<Or>(OutputVector{broadcast_freq, convert_broadcast_freq});
188190
189191 auto matmul_freq_pos_id = wrap_type<ov::op::v0::MatMul>({broadcast_freq_, convert_pos_id_to_f16});
190- auto transpose = wrap_type<ov::op::v1::Transpose>({matmul_freq_pos_id, any_input ()});
192+ auto reshape_matmul = wrap_type<ov::op::v1::Reshape>({matmul_freq_pos_id, any_input ()});
193+ auto transpose_matmul = wrap_type<ov::op::v1::Transpose>({matmul_freq_pos_id, any_input ()});
194+ auto transpose_or_reshape = std::make_shared<Or>(OutputVector{transpose_matmul, reshape_matmul});
191195
192- auto sin = wrap_type<ov::op::v0::Sin>({transpose });
193- auto sin_convert = wrap_type<ov::op::v0::Convert>({sin });
194- auto sin_ = std::make_shared<Or>(OutputVector{sin , sin_convert});
196+ auto sin_ = wrap_type<ov::op::v0::Sin>({transpose_or_reshape });
197+ auto sin_convert = wrap_type<ov::op::v0::Convert>({sin_ });
198+ auto sin = std::make_shared<Or>(OutputVector{sin_ , sin_convert});
195199
196- auto cos = wrap_type<ov::op::v0::Cos>({transpose });
200+ auto cos = wrap_type<ov::op::v0::Cos>({transpose_or_reshape });
197201 auto cos_convert = wrap_type<ov::op::v0::Convert>({cos});
198202 auto cos_ = std::make_shared<Or>(OutputVector{cos, cos_convert});
199203
200204 auto scale_const_sin = wrap_type<ov::op::v0::Constant>();
201205 auto scale_const_sin_convert = wrap_type<ov::op::v0::Convert>({scale_const_sin});
202206 auto scale_const_sin_ = std::make_shared<Or>(OutputVector{scale_const_sin, scale_const_sin_convert});
203- auto mul_sin_scale = wrap_type<ov::op::v1::Multiply>({sin_ , scale_const_sin_});
207+ auto mul_sin_scale = wrap_type<ov::op::v1::Multiply>({sin , scale_const_sin_});
204208
205209 auto scale_const_cos = wrap_type<ov::op::v0::Constant>();
206210 auto scale_const_cos_convert = wrap_type<ov::op::v0::Convert>({scale_const_cos});
207211 auto scale_const_cos_ = std::make_shared<Or>(OutputVector{scale_const_cos, scale_const_cos_convert});
208212 auto mul_cos_scale = wrap_type<ov::op::v1::Multiply>({cos_, scale_const_cos_});
209213
210- auto unsqueeze_mul_sin_scale = wrap_type<ov::op::v0::Unsqueeze>({mul_sin_scale, any_input ()});
211- auto mul_q_sin = wrap_type<ov::op::v1::Multiply>({any_input ()/* q_second_half*/ , unsqueeze_mul_sin_scale});
214+ auto unsqueeze_mul_sin_scale_ = wrap_type<ov::op::v0::Unsqueeze>({mul_sin_scale, any_input ()});
215+ auto reshape_mul_sin_scale_ = wrap_type<ov::op::v1::Reshape>({mul_sin_scale, any_input ()});
216+ auto reshape_mul_sin_scale = std::make_shared<Or>(OutputVector{unsqueeze_mul_sin_scale_, reshape_mul_sin_scale_});
217+
218+ auto mul_qk_sin = wrap_type<ov::op::v1::Multiply>({any_input ()/* second_half*/ , reshape_mul_sin_scale});
212219
213- auto unsqueeze_mul_cos_scale = wrap_type<ov::op::v0::Unsqueeze>({mul_cos_scale, any_input ()});
214- auto mul_q_cos = wrap_type<ov::op::v1::Multiply>({any_input ()/* q_second_half*/ , unsqueeze_mul_cos_scale});
220+ auto unsqueeze_mul_cos_scale_ = wrap_type<ov::op::v0::Unsqueeze>({mul_cos_scale, any_input ()});
221+ auto reshape_mul_cos_scale_ = wrap_type<ov::op::v1::Reshape>({mul_cos_scale, any_input ()});
222+ auto reshape_mul_cos_scale = std::make_shared<Or>(OutputVector{unsqueeze_mul_cos_scale_, reshape_mul_cos_scale_});
223+ auto mul_qk_cos = wrap_type<ov::op::v1::Multiply>({any_input ()/* first_half*/ , reshape_mul_cos_scale});
215224
216- auto q_half_mul1 = wrap_type<ov::op::v1::Multiply>({any_input (), any_input ()});
217- auto q_half_mul2 = wrap_type<ov::op::v1::Multiply>({q_half_mul1, any_input ()});
218- auto q_half_first = wrap_type<ov::op::v1::Add>({mul_q_cos, q_half_mul2});
225+ auto qk_half_mul1 = wrap_type<ov::op::v1::Multiply>({any_input (), any_input ()});
226+ auto qk_half_mul2_1 = wrap_type<ov::op::v1::Multiply>({qk_half_mul1, any_input ()});
227+ auto qk_half_mul2_2 = wrap_type<ov::op::v1::Multiply>({any_input (), qk_half_mul1});
228+ auto qk_half_mul2 = std::make_shared<Or>(OutputVector{qk_half_mul2_1, qk_half_mul2_2});
229+ auto qk_half_first_1 = wrap_type<ov::op::v1::Add>({mul_qk_cos, qk_half_mul2});
230+ auto qk_half_first_2 = wrap_type<ov::op::v1::Add>({qk_half_mul2, mul_qk_cos});
231+ auto qk_half_first = std::make_shared<Or>(OutputVector{qk_half_first_1, qk_half_first_2});
219232
220- auto q_half_mul4 = wrap_type<ov::op::v1::Multiply>({any_input (), any_input ()});
221- auto q_half_second = wrap_type<ov::op::v1::Add>({mul_q_sin, q_half_mul4});
233+ auto qk_half_mul4 = wrap_type<ov::op::v1::Multiply>({any_input (), any_input ()});
234+ auto qk_half_second_1 = wrap_type<ov::op::v1::Add>({mul_qk_sin, qk_half_mul4});
235+ auto qk_half_second_2 = wrap_type<ov::op::v1::Add>({qk_half_mul4, mul_qk_sin});
236+ auto qk_half_second = std::make_shared<Or>(OutputVector{qk_half_second_1, qk_half_second_2});
222237
223- auto concat_q_1 = wrap_type<ov::op::v0::Concat>({q_half_second, q_half_first });
224- auto concat_q_2 = wrap_type<ov::op::v0::Concat>({q_half_first, q_half_second });
225- auto concat_q = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{concat_q_1, concat_q_2 });
238+ auto concat_qk_1 = wrap_type<ov::op::v0::Concat>({qk_half_second, qk_half_first });
239+ auto concat_qk_2 = wrap_type<ov::op::v0::Concat>({qk_half_first, qk_half_second });
240+ auto concat_qk = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{concat_qk_1, concat_qk_2 });
226241
227242 ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
228243 const auto & pattern_map = m.get_pattern_value_map ();
229244 bool matched = false ;
230- if (pattern_map.count (concat_q_1 ) > 0 ) {
245+ if (pattern_map.count (concat_qk_1 ) > 0 ) {
231246 matched = true ;
232- } else if (pattern_map.count (concat_q_2 ) > 0 ) {
247+ } else if (pattern_map.count (concat_qk_2 ) > 0 ) {
233248 matched = true ;
234249 }
235- if (!matched || transformation_callback (concat_q ))
250+ if (!matched || transformation_callback (concat_qk ))
236251 return false ;
237252
238253 std::shared_ptr<ov::op::v0::Concat> output_concat_node;
239- if (pattern_map.count (concat_q_1 ) > 0 ) {
240- output_concat_node = ov::as_type_ptr<ov::op::v0::Concat>(pattern_map.at (concat_q_1 ).get_node_shared_ptr ());
241- } else if (pattern_map.count (concat_q_2 ) > 0 ) {
242- output_concat_node = ov::as_type_ptr<ov::op::v0::Concat>(pattern_map.at (concat_q_2 ).get_node_shared_ptr ());
254+ if (pattern_map.count (concat_qk_1 ) > 0 ) {
255+ output_concat_node = ov::as_type_ptr<ov::op::v0::Concat>(pattern_map.at (concat_qk_1 ).get_node_shared_ptr ());
256+ } else if (pattern_map.count (concat_qk_2 ) > 0 ) {
257+ output_concat_node = ov::as_type_ptr<ov::op::v0::Concat>(pattern_map.at (concat_qk_2 ).get_node_shared_ptr ());
243258 }
244259 auto matmul_node = ov::as_type_ptr<ov::op::v0::MatMul>(pattern_map.at (matmul_freq_pos_id).get_node_shared_ptr ());
245260 auto mul_node1 = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at (mul_sin_scale).get_node_shared_ptr ());
246261 auto mul_node2 = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at (mul_cos_scale).get_node_shared_ptr ());
247- auto mul_node3 = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at (mul_q_sin ).get_node_shared_ptr ());
248- auto mul_node4 = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at (mul_q_cos ).get_node_shared_ptr ());
249- auto mul_node5 = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at (q_half_mul1 ).get_node_shared_ptr ());
250- auto mul_node6 = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at (q_half_mul2 ).get_node_shared_ptr ());
251- auto mul_node8 = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at (q_half_mul4 ).get_node_shared_ptr ());
262+ auto mul_node3 = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at (mul_qk_sin ).get_node_shared_ptr ());
263+ auto mul_node4 = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at (mul_qk_cos ).get_node_shared_ptr ());
264+ auto mul_node5 = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at (qk_half_mul1 ).get_node_shared_ptr ());
265+ auto mul_node6 = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at (qk_half_mul2 ).get_node_shared_ptr ());
266+ auto mul_node8 = ov::as_type_ptr<ov::op::v1::Multiply>(pattern_map.at (qk_half_mul4 ).get_node_shared_ptr ());
252267
253268 const auto desired_et = ov::element::f32 ;
254269 const auto original_et = output_concat_node->get_output_element_type (0 );
@@ -268,7 +283,7 @@ IncreasePositionIdsPrecisionForGPTOSS::IncreasePositionIdsPrecisionForGPTOSS() {
268283 return true ;
269284 };
270285
271- auto m = std::make_shared<ov::pass::pattern::Matcher>(concat_q , " IncreasePositionIdsPrecisionForGPTOSS" );
286+ auto m = std::make_shared<ov::pass::pattern::Matcher>(concat_qk , " IncreasePositionIdsPrecisionForGPTOSS" );
272287 this ->register_matcher (m, callback);
273288}
274289
0 commit comments