diff --git a/modules/loaders.py b/modules/loaders.py
index cd864e406d..7173ea67b2 100644
--- a/modules/loaders.py
+++ b/modules/loaders.py
@@ -12,6 +12,7 @@
         'alpha_value',
         'compress_pos_emb',
         'compute_dtype',
+        'cache_type',
         'quant_type',
         'load_in_8bit',
         'load_in_4bit',
diff --git a/modules/shared.py b/modules/shared.py
index 2e91f4d5a3..3e76301861 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -158,7 +158,7 @@
 
 # Cache
 group = parser.add_argument_group('Cache')
-group.add_argument('--cache_type', type=str, default='fp16', help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.')
+group.add_argument('--cache_type', type=str, default='fp16', help='KV cache type; valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4; transformers - fp16, quanto4, quanto2, hqq4, hqq2.')
 
 # DeepSpeed
 group = parser.add_argument_group('DeepSpeed')
diff --git a/modules/text_generation.py b/modules/text_generation.py
index 152b2b8df0..b9b617a5e0 100644
--- a/modules/text_generation.py
+++ b/modules/text_generation.py
@@ -11,6 +11,9 @@
 import transformers
 from transformers import (
     LogitsProcessorList,
+    QuantoQuantizedCache,
+    HQQQuantizedCache,
+    QuantizedCacheConfig,
     is_torch_npu_available,
     is_torch_xpu_available
 )
@@ -65,6 +68,32 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
         logger.info("PROMPT=")
         print_prompt(question)
 
+    shared_cache = None
+    if generate_func == generate_reply_HF:
+        if shared.args.loader == 'Transformers':
+            if shared.args.cache_type in ['quanto4', 'quanto2']:
+                cache_config = QuantizedCacheConfig(
+                    axis_key=0,
+                    axis_value=0,
+                    backend='quanto',
+                    nbits=4 if shared.args.cache_type == 'quanto4' else 2,
+                    device=get_device(),
+                    compute_dtype=shared.args.compute_dtype
+                )
+                shared_cache = QuantoQuantizedCache(cache_config=cache_config)
+            elif shared.args.cache_type in ['hqq4', 'hqq2']:
+                cache_config = QuantizedCacheConfig(
+                    axis_key=1,
+                    axis_value=1,
+                    backend='hqq',
+                    nbits=4 if shared.args.cache_type == 'hqq4' else 2,
+                    device=get_device(),
+                    compute_dtype=shared.args.compute_dtype
+                )
+                shared_cache = HQQQuantizedCache(cache_config=cache_config)
+
+
     # Prepare the input
     original_question = question
     if not is_chat:
@@ -94,7 +123,7 @@ def _generate_reply(question, state, stopping_strings=None, is_chat=False, escap
     min_update_interval = 1 / state['max_updates_second']
 
     # Generate
-    for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat):
+    for reply in generate_func(question, original_question, seed, state, stopping_strings, is_chat=is_chat, shared_cache=shared_cache):
         reply, stop_found = apply_stopping_strings(reply, all_stop_strings)
         if escape_html:
             reply = html.escape(reply)
@@ -282,7 +311,7 @@ def get_reply_from_output_ids(output_ids, state=None, starting_from=0):
     return reply
 
 
-def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False):
+def generate_reply_HF(question, original_question, seed, state, stopping_strings=None, is_chat=False, shared_cache=None):
     if shared.args.loader == 'Transformers':
         clear_torch_cache()
 
@@ -377,6 +406,10 @@ def generate_reply_HF(question, original_question, seed, state, stopping_strings
     generate_params['stopping_criteria'] = transformers.StoppingCriteriaList()
     generate_params['stopping_criteria'].append(_StopEverythingStoppingCriteria())
 
+    if shared_cache:
+        generate_params['past_key_values'] = shared_cache
+        generate_params['use_cache'] = True
+
     # Logits processor
     processor = state.get('logits_processor', LogitsProcessorList([]))
     if not isinstance(processor, LogitsProcessorList):
         processor = LogitsProcessorList([processor])
@@ -460,7 +493,7 @@ def generate_with_streaming(**kwargs):
     return
 
 
-def generate_reply_custom(question, original_question, seed, state, stopping_strings=None, is_chat=False):
+def generate_reply_custom(question, original_question, seed, state, stopping_strings=None, is_chat=False, shared_cache=None):
     """
     For models that do not use the transformers library for sampling
     """
diff --git a/modules/ui_model_menu.py b/modules/ui_model_menu.py
index 1264a9fd67..f0b76eb80b 100644
--- a/modules/ui_model_menu.py
+++ b/modules/ui_model_menu.py
@@ -91,7 +91,7 @@ def create_ui():
                     shared.gradio['hqq_backend'] = gr.Dropdown(label="hqq_backend", choices=["PYTORCH", "PYTORCH_COMPILE", "ATEN"], value=shared.args.hqq_backend)
                     shared.gradio['n_ctx'] = gr.Number(label="n_ctx", precision=0, step=256, value=shared.args.n_ctx, info='Context length. ⚠️ Lower this value if you can\'t load the model. Common values: 2048, 4096, 8192, 16384, 32768.')
                     shared.gradio['max_seq_len'] = gr.Number(label='max_seq_len', precision=0, step=256, value=shared.args.max_seq_len, info='Context length. ⚠️ Lower this value if you can\'t load the model. Common values: 2048, 4096, 8192, 16384, 32768.')
-                    shared.gradio['cache_type'] = gr.Dropdown(label="cache_type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q6', 'q4'], value=shared.args.cache_type, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4.')
+                    shared.gradio['cache_type'] = gr.Dropdown(label="cache_type", choices=['fp16', 'q8_0', 'q4_0', 'fp8', 'q8', 'q6', 'q4', 'quanto4', 'quanto2', 'hqq4', 'hqq2'], value=shared.args.cache_type, info='Valid options: llama.cpp - fp16, q8_0, q4_0; ExLlamaV2 - fp16, fp8, q8, q6, q4; transformers - fp16, quanto4, quanto2, hqq4, hqq2.')
                     shared.gradio['tensor_split'] = gr.Textbox(label='tensor_split', info='List of proportions to split the model across multiple GPUs. Example: 60,40')
                     shared.gradio['gpu_split'] = gr.Textbox(label='gpu-split', info='Comma-separated list of VRAM (in GB) to use per GPU. Example: 20,7,7')
                     shared.gradio['alpha_value'] = gr.Number(label='alpha_value', value=shared.args.alpha_value, precision=2, info='Positional embeddings alpha factor for NTK RoPE scaling. Recommended values (NTKv1): 1.75 for 1.5x context, 2.5 for 2x context. Use either this or compress_pos_emb, not both.')
diff --git a/requirements.txt b/requirements.txt
index d09f6bf582..c3e1c0b889 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -24,6 +24,7 @@ tensorboard
 transformers==4.49.*
 tqdm
 wandb
+optimum-quanto>=0.2.6
 
 # API
 SpeechRecognition==3.10.0
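
For reference, below is a minimal standalone sketch (not part of the patch) of the transformers quantized-KV-cache API that the change above relies on: build a QuantizedCacheConfig, wrap it in a QuantoQuantizedCache, and hand the cache to generate() via past_key_values. The model name, prompt, and generation settings are illustrative placeholders, and the quanto backend assumes optimum-quanto is installed (matching the requirements.txt addition).

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    QuantizedCacheConfig,
    QuantoQuantizedCache
)

# Illustrative model; any causal LM that supports the transformers Cache API works the same way.
model_id = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")

# 4-bit quanto cache: quanto quantizes along axis 0, HQQ along axis 1 (mirroring the patch above).
cache_config = QuantizedCacheConfig(
    backend="quanto",
    nbits=4,
    axis_key=0,
    axis_value=0,
    device=model.device,
    compute_dtype=torch.float16
)
cache = QuantoQuantizedCache(cache_config=cache_config)

inputs = tokenizer("Quantized KV caches trade a little accuracy for", return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=32, past_key_values=cache, use_cache=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))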