diff --git a/samples/python/text_generation/limit_checker.py b/samples/python/text_generation/limit_checker.py new file mode 100644 index 0000000000..66d8928286 --- /dev/null +++ b/samples/python/text_generation/limit_checker.py @@ -0,0 +1,241 @@ + +import gc +import os +import psutil +import csv +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Optional +from tqdm import tqdm + +from optimum.intel.openvino import OVModelForCausalLM +from openvino_genai import ContinuousBatchingPipeline, SchedulerConfig, GenerationResult, GenerationConfig, CacheEvictionConfig, AggregationMode +from openvino_tokenizers import convert_tokenizer +from openvino import serialize +from transformers import AutoTokenizer +import argparse + +import time +import logging +from huggingface_hub.utils import HfHubHTTPError +from subprocess import CalledProcessError # nosec B404 +from requests.exceptions import RequestException + +# Configure the logger +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +def retry_request(func, retries=5): + """ + Retries a function that makes a request up to a specified number of times. + + Parameters: + func (callable): The function to be retried. It should be a callable that makes a request. + retries (int): The number of retry attempts. Default is 5. + + Returns: + Any: The return value of the function `func` if it succeeds. + """ + network_error_patterns = [ + "ConnectionError", + "Timeout", + "Time-out", + "ServiceUnavailable", + "InternalServerError" + ] + + for attempt in range(retries): + try: + return func() + except (CalledProcessError, RequestException, HfHubHTTPError) as e: + if isinstance(e, CalledProcessError): + if any(pattern in e.stderr for pattern in network_error_patterns): + logger.warning(f"CalledProcessError occurred: {e.stderr}") + else: + raise e + if attempt < retries - 1: + timeout = 2 ** attempt + logger.info(f"Attempt {attempt + 1} failed. Retrying in {timeout} seconds.") + time.sleep(timeout) + else: + raise e + +def load_prompts_dataset(file_name : str) -> Dict[str, List[str]]: + TESTS_ROOT = Path('tests/python_tests') + file_path = TESTS_ROOT / 'data' / file_name + with open(file_path, 'r') as f: + return {"prompts": [s for s in f]} + +def load_samsum_dataset(file_name : str) -> Dict[str, List[str]]: + import json + retval = {"prompts": []} + with open(file_name, 'r') as json_file: + json_list = list(json_file) + for json_str in json_list: + result = json.loads(json_str) + retval["prompts"].append(result["prompt"]) + return retval + +def get_scheduler_config(num_kv_blocks: Optional[int]) -> SchedulerConfig: + scheduler_config = SchedulerConfig() + if num_kv_blocks is not None: + scheduler_config.num_kv_blocks = num_kv_blocks + scheduler_config.dynamic_split_fuse = True + scheduler_config.max_num_batched_tokens = 32 * num_kv_blocks + scheduler_config.max_num_seqs = 256 + scheduler_config.use_cache_eviction = False + return scheduler_config + +@dataclass +class ConvertedModel: + model: OVModelForCausalLM + tokenizer: AutoTokenizer + models_path: Path + + +def get_converted_model(base_model_path: Path, model_id: str): + model = retry_request(lambda: OVModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, load_in_8bit=False, compile=False, ov_config=get_default_llm_properties())) + tokenizer = retry_request(lambda: AutoTokenizer.from_pretrained(model_id)) + models_path = base_model_path / model_id + models_path.mkdir(parents=True, exist_ok=True) + model.save_pretrained(models_path) + ov_tokenizer, ov_detokenizer = convert_tokenizer(tokenizer, with_detokenizer=True, skip_special_tokens=True) + serialize(ov_tokenizer, models_path / "openvino_tokenizer.xml") + serialize(ov_detokenizer, models_path / "openvino_detokenizer.xml") + converted_model = ConvertedModel(model, tokenizer, models_path) + return converted_model + + +import openvino.properties.hint as hints +import openvino.properties as props +import openvino as ov + +def get_default_llm_properties(): + return { + hints.inference_precision : ov.Type.f32, + hints.kv_cache_precision : ov.Type.f16, + } + +def run_and_write_metrics(model, prompt, generation_config, report_file): + result: GenerationResult = model_cb_opt.generate([prompt], generation_config=[generation_config]) + + pipeline_opt_metrics = model_cb_opt.get_metrics() + rss_usage_gb = psutil.Process(os.getpid()).memory_info().rss / 1024 ** 3 + result_length = len(result[0].m_generation_ids[0]) + print(f"avg_cache_usage:{pipeline_opt_metrics.avg_cache_usage:.2f}% max_cache_usage:{pipeline_opt_metrics.max_cache_usage:.2f}% rss_usage:{rss_usage_gb:.3f} GB") + print(f"result length: {result_length}") + print() + + if report_file is not None: + with open(report_file, 'a') as f: + csv_writer = csv.writer(f) + csv_writer.writerow([generation_config.max_new_tokens - 1, result_length, pipeline_opt_metrics.avg_cache_usage, pipeline_opt_metrics.max_cache_usage, rss_usage_gb]) + return pipeline_opt_metrics.max_cache_usage + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--eviction_on", action='store_true', help="Whether to apply cache eviction") + parser.add_argument("--model", type=str, help="Model ID") + parser.add_argument("--num_kv_blocks", type=int, help='Number of blocks to statically pre-allocate in cache.' + 'If left unspecified, will allocate dynamically to accomodate the generation length.') + parser.add_argument("--report", type=str, help="File name for CSV-formatted export of limit search data") + parser.add_argument("--mode", type=str, nargs='?', choices=['gen_length', 'gen_throughput'], required=True) + parser.add_argument("--data", type=str, help="Dataset jsonl file") + parser.add_argument("--timeout", type=int, help="Maximum time allowed for a single round of generation in the `gen_length` mode", default=120) + parser.add_argument("--device", type=str, help="Device for model inference", default="CPU") + + args = parser.parse_args() + seqs_per_request = 1 + num_kv_blocks = args.num_kv_blocks + + scheduler_config_opt = get_scheduler_config(num_kv_blocks) + if args.eviction_on: + scheduler_config_opt.use_cache_eviction = True + print("Eviction is ON") + else: + print("Eviction is OFF") + + base_model_path = Path("limit_checker_models") + converted_model = get_converted_model(base_model_path, args.model) + models_path = converted_model.models_path + model_cb_opt = ContinuousBatchingPipeline(models_path, scheduler_config_opt, args.device, {}, get_default_llm_properties()) + + tokenizer = converted_model.tokenizer + if args.mode == "gen_length": + data_dict = load_prompts_dataset('long_prompts.txt') + prompt = data_dict["prompts"][0] + + generation_length = 1 + + if args.report is not None: + with open(args.report, 'w') as f: + csv_writer = csv.writer(f) + csv_writer.writerow(['generation_length', 'result_length', 'avg_cache_usage_%', 'max_cache_usage_%', 'rss_usage_gb']) + + + while True: + gc.collect() + generation_config = GenerationConfig() # expecting default greedy sampling + generation_config.num_return_sequences = 1 + generation_config.max_new_tokens = generation_length + 1 + generation_config.apply_chat_template = False + generation_config.ignore_eos = True + print(f"generation_length:{generation_length} ", sep='') + + start = time.time() + max_cache_usage = run_and_write_metrics(model_cb_opt, prompt, generation_config, args.report) + end = time.time() + if (end - start) > args.timeout: + print("Maximum generation time reached") + break + elif max_cache_usage == 100: + print("Cache size exhausted") + break + + generation_length *= 2 + + del data_dict + elif args.mode == "gen_throughput": + dataset = load_samsum_dataset(args.data) + prompt_throughput = 1 + prompt_left_bound = prompt_throughput + prompt_right_bound = None + is_right_bound = False + + while True: + gc.collect() + generation_config = GenerationConfig() # expecting default greedy sampling + generation_config.num_return_sequences = 1 + generation_config.apply_chat_template = False + prompt_subset = dataset["prompts"][:prompt_throughput] + print(f"prompt_throughput {prompt_throughput}") + result: GenerationResult = model_cb_opt.generate(prompt_subset, generation_config=[generation_config] * len(prompt_subset)) + + pipeline_opt_metrics = model_cb_opt.get_metrics() + rss_usage_gb = psutil.Process(os.getpid()).memory_info().rss / 1024 ** 3 + print(f"avg_cache_usage:{pipeline_opt_metrics.avg_cache_usage:.2f}% max_cache_usage:{pipeline_opt_metrics.max_cache_usage:.2f}% rss_usage:{rss_usage_gb:.3f} GB") + print() + + max_cache_usage = pipeline_opt_metrics.max_cache_usage + + if max_cache_usage == 100.0 and not is_right_bound: + is_right_bound = True + prompt_right_bound = prompt_throughput + + if not is_right_bound: + prompt_left_bound = prompt_throughput + prompt_throughput *= 2 + else: + if max_cache_usage == 100.0: + prompt_right_bound = prompt_throughput + elif max_cache_usage < 100.0: + prompt_left_bound = prompt_throughput + prompt_throughput = (prompt_left_bound + prompt_right_bound) // 2 + + if (prompt_right_bound - prompt_left_bound <= 1): + break + + + print(f"Approximate highest throughput: {prompt_throughput} prompts") + diff --git a/site/docs/concepts/optimization-techniques/kvcache-eviction-algorithm.md b/site/docs/concepts/optimization-techniques/kvcache-eviction-algorithm.md index 9e8820eb7d..95376a45d6 100644 --- a/site/docs/concepts/optimization-techniques/kvcache-eviction-algorithm.md +++ b/site/docs/concepts/optimization-techniques/kvcache-eviction-algorithm.md @@ -3,3 +3,60 @@ sidebar_position: 2 --- # KVCache Token Eviction Algorithm + + +## Overview +The cache eviction algorithm is designed to manage KV (Key-Value) cache memory for large language models (LLMs) during text generation. It determines which blocks of tokens should be evicted from the KV cache based on importance scores calculated from attention scores across different attention layers. + +## Conceptual Model +The KV cache for each sequence is divided into three logical areas: + +![KV cache layout with cache eviction](/img/kv-cache-areas-diagram.svg) + +* Start Area: Initial tokens that are never evicted +* Evictable Area: Tokens that can be evicted based on importance scores +* Recent Area: Most recent tokens that are preserved (not evicted while in this area, but naturally migrating toward the evictable area as the text generation goes on) + +The sizes of all three areas can be configured by modifying corresponding fields in a `CacheEvictionConfig` struct, which itself is a part of the pipeline-wide `SchedulerConfig`. +As the generation starts, the blocks in respective logical areas are filled token-by-token, and once at least one block past the "recent" area is filled, eviction may take place. +The tokens are evicted based on accumulated importance scores following the [H2O](https://arxiv.org/abs/2306.14048) approach. +The scores are accumulated throughout the entire generation process and their weighting may be changed by adjusting the `CacheEvictionConfig.aggregation_mode` parameter. +Eviction occurs with a block-wise granularity, and only the completely filled blocks from the "evictable" area are evicted. +By default the start area is 32 tokens, evictable area is 512 tokens and recent area is 128 tokens, which amounts to a total maximum cache usage by sequence during the generation phase of 672 tokens. + +This approach allows LLMs to handle long sequences efficiently by keeping the most contextually important tokens in the cache while evicting those of lesser importance. +The downside of the eviction procedure is potential loss of generation accuracy, since the cache no longer contains the entire context for the generation, but only the most "important" token blocks. +The user can adjust the individual sizes of the eviction sub-areas to hit the optimal point of accuracy/memory usage tradeoff in their particular case. + +Note that currently the eviction only starts after the full prompt has been processed, i.e. no eviction takes place during the prefill phase. +This means that for longer prompt sizes the maximum cache usage may exceed the limit defined by the `CacheEvictionConfig` parameters. + +After the prefill phase, however, the maximum cache occupancy for each sequence currently being processed is strictly limited by the combined sizes of the 3 areas described above. +`CacheEvictionConfig.get_max_cache_size_after_eviction()` can be queried to get this cache size limit in tokens. + + +## Sample - impact of cache eviction on possible generation length and prompt throughput +[limit_checker.py](https://github.com/openvinotoolkit/openvino.genai/tree/master/samples/python/text_generation/limit_checker.py) can be used to visualize the impact of the cache eviction algorithm on the end performance of the generation pipeline. +The script is paramaterized to allow specifying own model (by its `huggingface_hub` ID) and the base cache size. + +With `--mode gen_length`, the script will run the generation pipeline with increasing requested length of generation until it either hits 100% maximum cache usage or times out. +With cache eviction disabled, the pipeline will eventually exhaust the cache size, and the generation length will be capped at the output token count determined by the base cache size. +With eviction enabled, however, the pipeline is able to generate sequences of arbitrary length (as long as the cache size is at least `max(prompt_size, max_cache_size_after_eviction)`, and the script will instead finish with a timeout. + +With `--mode gen_throughput`, the script will run a binary search to determine the minimum number of concurrently processed sequences to hit the 100% cache utilization. + + +## (Optional) Cache Rotation +By default, no additional cache modification is performed during eviction. +Most LLMs employ some kind of positional embedding at some point in the inferencing, which effectively becomes associated with each per-token KV cache vector as well. +The popular RoPE positional embedding is more or less continuous in the linear space of the token positions, but when token eviction takes place, the continuity of the remaining blocks is disrupted. +This may impact the ability of the model to correctly recognize the relative positions of the remaining blocks and degrade the generation accuracy. + +Cache rotation seeks to alleviate this by "re-rotating" corresponding blocks so that the blocks that remain after each eviction are once again "continuous" in terms of the effective RoPE embedding. +It can be enabled by setting the `CacheEvictionConfig.apply_rotation` field to `true` (default is `false`). + +## Current limitations + +* Cache rotation is only targeted for the regular, linear LLaMa-like RoPE application and may degrade accuracy on models that use other RoPE schemes. + +* Cache rotation is currently only supported for the models with uniform V embedding sizes across the layers. diff --git a/site/static/img/kv-cache-areas-diagram.svg b/site/static/img/kv-cache-areas-diagram.svg new file mode 100644 index 0000000000..822f33c2c4 --- /dev/null +++ b/site/static/img/kv-cache-areas-diagram.svg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e2fa45a69b4db6e8293fd8e1da712c2970237ac98aab99d4b0d729379bbe49c6 +size 7143