4 changes: 3 additions & 1 deletion src/plugins/intel_gpu/README.md
@@ -33,7 +33,7 @@ GPU Plugin contains the following components:
* [Run benchmark from device_mem](./docs/use_device_mem.md)

## Documentation on dynamic-shape
This contents explain the internal implementation of dynamic shape support in the GPU Plugin. For general usage of dynamic shape and limitations of the GPU plugin, please refer to this link: [GPU Device — OpenVINO™ documentation - Version(2023.1)](https://docs.openvino.ai/2023.1/openvino_docs_OV_UG_supported_plugins_GPU.html#dynamic-shapes).
This content explains the internal implementation of dynamic shape support in the GPU Plugin. For general usage of dynamic shape and limitations of the GPU plugin, please refer to this link: [GPU Device — OpenVINO™ documentation - Version(2025)](https://docs.openvino.ai/2025/openvino-workflow/running-inference/inference-devices-and-modes/gpu-device.html#dynamic-shapes).

* [Overall flow for dynamic shape execution](./docs/dynamic_shape/overall_flow.md)
* Implementation details
@@ -44,6 +44,8 @@ This contents explain the internal implementation of dynamic shape support in th
<!-- * weight compression (TBD)) -->
* Optimization features
* [Memory preallocation](./docs/dynamic_shape/memory_preallocation.md)
* [Runtime operation skip](./docs/dynamic_shape/runtime_skip.md)
* [KV cache optimization](./docs/dynamic_shape/kv_cache.md)
<!-- * Fake alignment of shape (TBD)
* Shape-of subgraph on CPU (TBD)
* Runtime buffer fusing (TBD)
115 changes: 115 additions & 0 deletions src/plugins/intel_gpu/docs/dynamic_shape/kv_cache.md
@@ -0,0 +1,115 @@
# KV cache management in stateful model
## Description

For auto-regressive transformer models, the KV (key-value) cache plays a pivotal role in avoiding duplicated computation for past tokens. In OpenVINO, management of the KV cache differs between the [continuous batching (CB) pipeline](https://docs.openvino.ai/2025/model-server/ovms_demos_continuous_batching.html) and the non-CB pipeline.

In the CB pipeline with openvino_genai, an SDPA (Scaled Dot Product Attention) operation in the OV IR (intermediate representation) is converted to a PagedAttention operation, and the KV cache memory is managed by the CB pipeline. The CB pipeline is the default mode of openvino_genai for running LLMs.

On the other hand, the non-CB pipeline can use a stateful model for KV cache optimization. In this case, the KV cache memory is managed by the plugin.

This document explains KV cache management for a stateful model within a non-CB pipeline.


## Basic structure of SDPA and KV cache in a stateful model

The following diagram shows a typical KV cache pattern in a stateful LLM.

```mermaid
graph TB
StateInit["State<br/>Initialize<br/>Subgraph"]
ReadValue["ReadValue<br/>(Past KV)"]
BeamIdx["Parameter<br/>beam_idx"]
Gather["Gather"]
KVProj["KV projection<br/>for new token"]
Concat["Concat"]
SDPA["SDPA"]
Assign["Assign<br/>(Present KV)"]
VariableState["VariableState"]

%% flows
StateInit --> ReadValue
ReadValue -.-> VariableState
ReadValue --> Gather
BeamIdx --> Gather
Gather --> Concat
KVProj --> Concat
Concat --> SDPA
Concat --> Assign
Assign -.-> VariableState
```

Here, the KV cache from previous tokens is stored in the VariableState's memory and loaded by the ReadValue operation. It is then concatenated with the KV produced from the new token. The combined KV is written back to the VariableState's memory via the Assign operation as the updated (present) KV cache. This present KV is consumed by the SDPA operation. You can also see a Gather operation after the ReadValue operation; it is used only when the beam_idx input is given for beam search sampling.

Here, the GPU plugin applies optimizations targeting three key points:
1. reducing the concatenation overhead,
2. improving KV cache memory (VariableState) allocation efficiency, and
3. eliminating the Gather operation for beam search (i.e., indirect SDPA).

To this end, we introduced an internal operation, [ov::intel_gpu::KVCache](https://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_gpu/include/intel_gpu/op/kv_cache.hpp).


First, the Concat and Assign operations are fused into a single KVCache operation by the [KVCacheFusion](https://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_fusion.hpp) transformation. Next, the [IndirectKVCache](https://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_gpu/src/plugin/transformations/indirect_kv_cache.hpp) transformation further fuses the Gather operation into the same KVCache operation. (Refer to the diagrams in the header file of each pass for the detailed fusion patterns.)

The resulting graph after the transformations is as follows:

```mermaid
graph TB

StateInit["State<br/>Initialize<br/>Subgraph"]
BeamIdx["Parameter<br/>beam_idx"]
NewKV["KV projection<br/>for new token"]
KVCache["KVCache"]
SDPA["SDPA"]
VariableState["VariableState"]
ReadValue["ReadValue"]

%% main flows
StateInit --> ReadValue
ReadValue --> KVCache
BeamIdx --> KVCache
NewKV --> KVCache

KVCache -- beam_table --> SDPA
KVCache -- KV --> SDPA

KVCache -.-> VariableState
ReadValue -.-> VariableState
```

## Optimizations for KVCache operation

The following sections describe the policies and optimizations applied to the KVCache operation.

### In-place concatenation

As described above, the KV projection data used as the input to the current SDPA step is formed by concatenating the KV data from past tokens with that of the new token. In the original graph, this concatenation requires two memory copies: one for the past KV and another for the new KV. With intel_gpu::KVCache, we reduce this to a single copy by:
1) allocating sufficient contiguous memory in advance,
2) applying appropriate padding so the new KV data is written directly to its target location, and
3) preserving the existing past KV data and copying only the new KV data.

The relevant code sections that enable the in-place KV cache are as follows:
- [prepare_buffer_fusing](https://github.com/openvinotoolkit/openvino/blob/792ddf38fe3da130c2b3e11662374ec9ca3a2624/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp#L913)
  - At compilation time, it sets dynamic padding for the seq_len dimension of a `kv_cache` node.
- [do_runtime_in_place_kv_cache](https://github.com/openvinotoolkit/openvino/blob/792ddf38fe3da130c2b3e11662374ec9ca3a2624/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L1448)
  - If sufficient memory is available to store the concatenated KV data of past and new tokens, it updates the padding of the VariableState's layout to match the memory remaining after concatenation. Then, it sets the `kv_cache` primitive_inst's `_can_be_optimized` flag to true.
- [kv_cache_inst::realloc_if_needed](https://github.com/openvinotoolkit/openvino/blob/792ddf38fe3da130c2b3e11662374ec9ca3a2624/src/plugins/intel_gpu/src/graph/primitive_inst.cpp#L699)
  - If the `kv_cache` primitive_inst's `_can_be_optimized` flag is true
    - It sets the `kv_cache` primitive_inst's output memory to the VariableState's memory.
  - Otherwise (i.e., `_can_be_optimized` is false)
    - It allocates new output memory for the `kv_cache` inst and then sets it as the VariableState's memory. At this time, we allocate more memory for `kv_cache` than is actually needed (i.e., [preallocate](https://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_gpu/docs/dynamic_shape/memory_preallocation.md)); the details of this policy are described in the next subsection.
- [kv_cache_inst::execute::concat_stage](https://github.com/openvinotoolkit/openvino/blob/792ddf38fe3da130c2b3e11662374ec9ca3a2624/src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp#L87)
  - Basically, the kv_cache_inst concatenates the past and new KV data using two `concatenate` kernels.
  - When the instance's `_can_be_optimized` flag is false, it launches one concatenate kernel per input (two in total, for the past and new KV), with each kernel copying its input data into the target memory region.
  - When the `_can_be_optimized` flag is true, the `concatenate` (i.e., copy) of the first input (the past KV) is skipped, reducing the operation to a single concatenate kernel for the new KV input (see [link](https://github.com/openvinotoolkit/openvino/blob/792ddf38fe3da130c2b3e11662374ec9ca3a2624/src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp#L501)).

Suppose the KV cache has shape [batch, num_heads, seq_len, head_size]. For example, consider the case where the past KV cache's data shape is [1, 32, **17**, 64] and the new KV projection data shape is [1, 32, **1**, 64]. If the preallocated memory can hold data with shape [1, 32, **128**, 64], `do_runtime_in_place_kv_cache` sets the upper padding of the `kv_cache` inst to [0, 0, **111**, 0]. At execution time, the past KV data remains unchanged, and the new token's KV values are written into the remaining space along the concat axis. As a result, the effective data shape becomes [1, 32, **18**, 64], and the updated upper padding becomes [0, 0, **110**, 0].
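
The padding arithmetic of this example can be summarized with a small sketch. The snippet below is illustrative only (the struct and member names are hypothetical and not part of the plugin); it simply reproduces the numbers above, assuming concatenation along the seq_len axis (axis 2).

```cpp
#include <array>
#include <cstdint>

// Hypothetical helper mirroring the example above; not the plugin's actual code.
// Shapes follow [batch, num_heads, seq_len, head_size]; the concat axis is seq_len (2).
struct InPlaceKVCacheExample {
    std::array<int64_t, 4> capacity = {1, 32, 128, 64};  // preallocated buffer shape
    std::array<int64_t, 4> past     = {1, 32, 17, 64};   // past KV data shape
    std::array<int64_t, 4> new_kv   = {1, 32, 1, 64};    // new token's KV shape

    // Upper padding along seq_len before the new token is written:
    // 128 - 17 = 111, i.e. [0, 0, 111, 0]
    int64_t upper_pad_before() const { return capacity[2] - past[2]; }

    // After writing the new token in place, the effective seq_len is 18
    // and the remaining upper padding is 128 - 18 = 110, i.e. [0, 0, 110, 0]
    int64_t upper_pad_after() const { return capacity[2] - (past[2] + new_kv[2]); }
};
```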

### Preallocate memory

To enable the in-place KV cache and reduce the host overhead of memory allocation, the GPU plugin allocates more memory for the KV cache than is actually needed. In the general case, we preallocate only enough memory for 10 iterations, but for the KV cache we allocate somewhat more (around 128).

However, we do not allocate the same size for all layers; each layer gets a different size. You can see how the allocation size is determined in this [function](https://github.com/openvinotoolkit/openvino/blob/fcb474cb3c38fbf1aa1faa3a133f3b3ec5c22f1c/src/plugins/intel_gpu/src/graph/kv_cache.cpp#L96).

We vary the memory allocation across layers to prevent periodic peaks in overall memory usage. If every `kv_cache` instance used the same allocation size, their reallocation intervals would be the same. When a reallocation occurs, the concatenation inside the `kv_cache` cannot be performed in place, so both the input and output buffers must coexist temporarily. This effectively doubles the required memory for that `kv_cache` inst. If all layers hit this reallocation point simultaneously, the combined effect produces a large memory spike. By staggering the allocation schedules across layers, we mitigate these synchronized peaks and significantly reduce overall memory pressure.
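
To make the staggering idea concrete, here is a minimal sketch of one possible policy. It is not the policy implemented in `kv_cache.cpp` (see the linked function above for the real logic); the base block size and per-layer offset are made-up values used only to show how different layers can avoid outgrowing their buffers at the same iterations.

```cpp
#include <cstdint>

// Hypothetical preallocation policy, for illustration only: each layer rounds its
// required seq_len up to a block whose size is shifted by the layer index, so the
// layers' reallocation points are spread out instead of coinciding.
int64_t preallocated_seq_len(int64_t required_seq_len, int64_t layer_index) {
    const int64_t base_block = 128;                  // assumed base preallocation size
    const int64_t block = base_block + layer_index;  // per-layer stagger (illustrative)
    return ((required_seq_len + block - 1) / block) * block;  // round up to the block
}
```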
<!-- * ### Indirect SDPA (TBD)) -->
<!-- * ### KV Cache compression (TBD)) -->
33 changes: 33 additions & 0 deletions src/plugins/intel_gpu/docs/dynamic_shape/runtime_skip.md
@@ -0,0 +1,33 @@
# Runtime operation skip
## Description
When working with dynamic shapes, compilation-time optimization faces inherent limitations since shape information remains undefined until runtime. This creates a two-phase optimization opportunity: while certain operations cannot be optimized during the initial compilation phase due to unknown shapes, they become prime candidates for runtime optimization once concrete shape information materializes during inference execution.

Consider a 4D permute operation with the transformation order [0, 2, 1, 3]. During compilation, the input shape is dynamic [-1, -1, -1, -1], so shape-based optimization is not applicable. However, there is a second chance to optimize this operation at runtime. Suppose the actual input shape resolves to [128, 1, 32, 64]. With this concrete information, we can recognize a critical insight: since dimension 1 has size 1, swapping dimensions 1 and 2 (as specified by the permute order [0, 2, 1, 3]) results in no actual data movement. The operation becomes essentially a metadata-only transformation: a simple reshape that requires no memory copying or data rearrangement.
This example demonstrates how runtime optimization can allow potentially expensive operations to be skipped entirely, highlighting the value of deferred optimization strategies in dynamic computation graphs.
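
As a sketch of how such a runtime decision can be made, the helper below checks whether a permute becomes a pure metadata change for a concrete shape. It is a simplified illustration under the assumption of a dense, non-padded layout; the function name is ours and this is not the plugin's actual implementation.

```cpp
#include <cstdint>
#include <vector>

// Simplified illustration (not the plugin's code): a permute moves no data when the
// non-trivial dimensions (size > 1) keep their relative order. In that case the
// permute reduces to a reshape and the operation can be skipped at runtime.
bool permute_is_metadata_only(const std::vector<int64_t>& shape,
                              const std::vector<size_t>& order) {
    std::vector<size_t> significant_axes;
    for (size_t dst = 0; dst < order.size(); ++dst) {
        const size_t src = order[dst];
        if (shape[src] > 1)
            significant_axes.push_back(src);
    }
    for (size_t i = 1; i < significant_axes.size(); ++i) {
        if (significant_axes[i] < significant_axes[i - 1])
            return false;  // a real transposition of non-trivial dimensions
    }
    return true;
}

// Example from the text: shape {128, 1, 32, 64} with order {0, 2, 1, 3} keeps the
// non-trivial axes 0, 2, 3 in their original relative order, so it returns true.
```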

## Basic flow of runtime operation skip
1. **Relevant flags**
First, two flags need to be set on the program_node of an operation for which we do not apply shape-based optimization during compilation but instead attempt runtime optimization once the shape is known.
- Static flags (set during the `mark_runtime_skippable_nodes` pass at compilation time)
  - `program_node::optimized`
    - Indicates that this node is eligible to be optimized out, either at compilation time or at runtime.
    - It is set to true for all optimization schemes, not only for runtime skippability.
  - `program_node::runtime_skippable`
    - Indicates that this node can be optimized out at runtime based on the shape.
- Dynamic flag (set at runtime)
  - `primitive_inst::_can_be_optimized`
    - Indicates that this `primitive_inst` is actually optimized out in a given execution.

If `program_node::optimized` is true and `program_node::runtime_skippable` is false, the node is *always* optimized out (i.e., compile-time optimization).
If both flags are set to true, the node may or may not be optimized out at runtime, depending on the runtime shapes.
If `program_node::optimized` is false and `program_node::runtime_skippable` is true, the combination is invalid. These combinations are summarized in the sketch at the end of this item.

As an example of using both flags, please refer to [memory_dependency_pass](https://github.com/openvinotoolkit/openvino/blob/aa6d3811e6dea93cb818ff483bf6c3ca849d4034/src/plugins/intel_gpu/src/graph/include/pass_manager.h#L313), which makes different decisions for dependency settings depending on whether a node is optimized at compile time or at runtime.
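
The valid combinations can be summarized with a small sketch. The struct below is illustrative only and does not exist in the plugin; it merely restates the three cases above.

```cpp
// Illustrative only; the real program_node stores these flags differently.
struct NodeOptimizationFlags {
    bool optimized;          // eligible to be optimized out (compile time or runtime)
    bool runtime_skippable;  // the final decision is deferred to runtime
};

// optimized = true,  runtime_skippable = false -> always optimized out at compile time
// optimized = true,  runtime_skippable = true  -> skipped or executed depending on runtime shapes
// optimized = false, runtime_skippable = true  -> invalid combination
inline bool is_valid_combination(const NodeOptimizationFlags& f) {
    return f.optimized || !f.runtime_skippable;
}
```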

2. **Runtime optimization decision**
- Once the shape is updated in `primitive_inst::prepare_primitive()`, `do_runtime_skip_*node_type*` for each operation type decides whether to skip the node for that execution.

3. **Caveats**
- Once `primitive_inst::_can_be_optimized` is set to true, the runtime only updates metadata such as shape or padding information and skips the actual execution.
- It also needs to set the primitive_inst's output memory to its input memory. This is done by `update_output_memory()`, called from `primitive_inst::on_execute()`.
- If you are adding a new type of skippable operation, please make sure that the primitive also implements the `update_output_memory()` function. The contract is sketched below.
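
The following sketch, using hypothetical types, restates that contract: when `_can_be_optimized` is set, the primitive only forwards its input memory to its output and skips the kernel launch. It is not the actual `primitive_inst` implementation.

```cpp
#include <memory>

// Hypothetical types for illustration; the real primitive_inst is far more involved.
struct memory {};

struct skippable_primitive_sketch {
    bool can_be_optimized = false;
    std::shared_ptr<memory> input_mem;
    std::shared_ptr<memory> output_mem;

    // Mirrors the role of update_output_memory(): the output aliases the input.
    void update_output_memory() { output_mem = input_mem; }

    void execute() {
        if (can_be_optimized) {
            update_output_memory();  // memory/metadata bookkeeping only
            return;                  // the actual kernel launch is skipped
        }
        // ... normal path: launch the actual kernel ...
    }
};
```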