Commit d36d3d6

[TRANSFORMATIONS] Fix sliding window handling for gpt-oss in SDPAToPA (#32939)
It was observed that the initial implementation of sliding window handling for gpt-oss was incorrect: instead of the real sliding window value, a 0 stub was always used.

* Fix the pattern so that it captures the real subgraph in the model and extracts the real sliding window value.
* Fix the test, which had been written against the incorrect post-transformation graph.

Tickets: [CVS-176323](https://jira.devtools.intel.com/browse/CVS-176323)

Signed-off-by: Andrii Staikov <[email protected]>
1 parent: 3539f6c

2 files changed: 8 additions & 8 deletions
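For context, the corrected transformation derives the positive window size from the mask subgraph's negative offset constant. Below is a minimal sketch of that derivation, mirroring the Squeeze → Convert → Multiply(−1) chain in the updated test; the function name and wrapper are hypothetical, not code from this commit:

```cpp
// Hypothetical sketch: turn the matched f32 `offset` constant (which holds
// the negative sliding window, e.g. -128) into the positive i32 value that
// the PagedAttention sliding_window input expects.
#include <memory>

#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/squeeze.hpp"

using namespace ov;

std::shared_ptr<Node> derive_sliding_window(const Output<Node>& offset) {
    // Drop the {1, 1, 1, 1} dimensions down to a scalar.
    auto squeezed = std::make_shared<op::v15::Squeeze>(offset);
    // f32 -> i32, matching the expected input element type.
    auto as_i32 = std::make_shared<op::v0::Convert>(squeezed, element::i32);
    // Negate: -(-128) == 128, the actual sliding window size.
    auto minus_one = op::v0::Constant::create(element::i32, Shape{}, {-1});
    return std::make_shared<op::v1::Multiply>(as_i32, minus_one);
}
```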

src/common/transformations/src/transformations/sdpa_to_paged_attention/state_management_pattern.cpp

Lines changed: 3 additions & 6 deletions
@@ -207,22 +207,19 @@ static std::tuple<std::shared_ptr<ov::Node>, std::shared_ptr<ov::Node>> gpt_oss_
     auto q_idx = pattern::any_input();
     auto kv_idx = pattern::any_input();
 
-    auto kv_idx_opt_conv_0 = pattern::optional<v0::Convert>();
-    auto kv_idx_opt_conv_1 = pattern::optional<v0::Convert>(kv_idx_opt_conv_0);
-    auto less_eq = pattern::wrap_type<v1::LessEqual>({q_idx, kv_idx_opt_conv_1});
+    auto kv_idx_opt_conv = pattern::optional<v0::Convert>(kv_idx);
 
     auto offset = wrap_type<v0::Constant>();
 
     auto add = wrap_type<v1::Add>({q_idx, offset});
-    auto opt_conv_2 = pattern::optional<v0::Convert>(add);
-    auto greater = pattern::wrap_type<v1::Greater>({kv_idx_opt_conv_1, opt_conv_2});
+    auto greater = pattern::wrap_type<v1::Greater>({kv_idx_opt_conv, add});
     auto bitwise_and = pattern::wrap_type<v13::BitwiseAnd>({any_input(), greater});
     auto bitwise_and_1 = pattern::wrap_type<v13::BitwiseAnd>({bitwise_and, any_input()});
     auto bitwise_and_2 = pattern::wrap_type<v13::BitwiseAnd>({any_input(), bitwise_and_1});
     auto bitwise_and_3 = pattern::wrap_type<v13::BitwiseAnd>({bitwise_and_2, any_input()});
     auto broadcast = pattern::wrap_type<v3::Broadcast>({bitwise_and_3, any_input()});
     auto select = pattern::wrap_type<v1::Select>({broadcast, any_input(), any_input()});
-    auto mask = pattern::wrap_type<v1::StridedSlice>({select, any_input(), any_input(), any_input()});
+    auto mask = pattern::wrap_type<v8::Slice>({select, any_input(), any_input(), any_input(), any_input()});
 
     return {mask, offset};
 }
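The key simplification above is that a single `pattern::optional<v0::Convert>` anchored directly on `kv_idx` matches the comparison both with and without an intervening Convert. A standalone sketch of that matcher behavior, assuming current OpenVINO pattern APIs (not code from this commit):

```cpp
#include <iostream>
#include <memory>

#include "openvino/op/convert.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/relu.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/label.hpp"
#include "openvino/pass/pattern/op/optional.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

using namespace ov;
namespace pattern = ov::pass::pattern;

int main() {
    // Graph under test: Parameter -> Convert -> Relu.
    auto param = std::make_shared<op::v0::Parameter>(element::f32, Shape{1});
    auto convert = std::make_shared<op::v0::Convert>(param, element::f16);
    auto relu = std::make_shared<op::v0::Relu>(convert);

    // Pattern: Relu fed by an *optional* Convert of any input. The same
    // pattern would also match Relu(param) with no Convert present.
    auto in = pattern::any_input();
    auto opt_conv = pattern::optional<op::v0::Convert>(in);
    auto relu_p = pattern::wrap_type<op::v0::Relu>({opt_conv});

    pattern::Matcher m(relu_p, "OptionalConvertDemo");
    std::cout << std::boolalpha << m.match(relu->output(0)) << "\n";  // true
}
```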

src/common/transformations/tests/op_conversions/sdpa_to_paged_attention_test.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2646,8 +2646,11 @@ TEST_F(SDPAToPATest, SDPAToPA_gpt_oss_General) {
                        }),
                        MOCK_VALUE);
 
-    auto scale = v0::Constant::create(element::f32, {}, {0.125000f});
-    auto sliding_window = v0::Constant::create(element::i32, {}, {0});
+    auto sliding_window_neg = makeConst(element::f32, ov::Shape({1, 1, 1, 1}), {-128.0f});
+    auto Squeeze2 = makeOP<v15::Squeeze>({sliding_window_neg}, {{"allow_axis_skip", false}});
+    auto Convert16 = makeOP<v0::Convert>({Squeeze2}, {{"destination_type", "i32"}});
+    auto sliding_window = makeOP<v1::Multiply>({Convert16, -1}, {{"auto_broadcast", "numpy"}});
+    auto scale = v0::Constant::create(element::f32, {}, {0.1250f});
     auto alibi_slopes_stub = v0::Constant::create(element::f32, Shape{0}, {});
     auto PagedAttentionExtension =
         std::make_shared<ov::op::PagedAttentionExtension>(OutputVector{Reshape1,
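For readers wanting to exercise this path end to end, a minimal sketch of how the transformation is typically invoked on a model, assuming the public `ov::pass::SDPAToPagedAttention` API (not part of this commit):

```cpp
#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/pass/sdpa_to_paged_attention.hpp"

// Rewrites ScaledDotProductAttention subgraphs (including the gpt-oss
// sliding-window mask handled by this fix) into PagedAttentionExtension ops.
void convert_to_paged_attention(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::SDPAToPagedAttention>();
    manager.run_passes(model);
}
```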
