From 3b627e1a937a316fdfba98450e38c8054f1a6130 Mon Sep 17 00:00:00 2001 From: dharmendrac Date: Thu, 9 Nov 2023 19:40:39 +0530 Subject: [PATCH] Support for Llama2 models inference via NeMo Framework Inference Container using TRT-LLM and Triton Inference Server --- .../LICENSE-Apache-2.0.txt | 177 ++ RetrievalAugmentedGeneration/LICENSE.md | 14 + .../llm-inference-server/Dockerfile | 22 + .../conversion_scripts/llama/build.py | 776 +++++++++ .../conversion_scripts/llama/weight.py | 1446 +++++++++++++++++ .../ensemble_models/gptnext | 1 + .../ensemble_models/llama/ensemble/1/.gitkeep | 0 .../llama/ensemble/config.pbtxt | 228 +++ .../llama/postprocessing/1/model.py | 173 ++ .../llama/postprocessing/config.pbtxt | 50 + .../llama/preprocessing/1/model.py | 244 +++ .../llama/preprocessing/config.pbtxt | 65 + .../llama/tensorrt_llm/1/.gitkeep | 0 .../llama/tensorrt_llm/config.pbtxt.j2 | 208 +++ .../model_server/__init__.py | 129 ++ .../model_server/__main__.py | 196 +++ .../model_server/conversion/__init__.py | 73 + .../model_server/conversion/llama.py | 96 ++ .../model_server/conversion/nemo.py | 71 + .../model_server/errors.py | 32 + .../model_server/model.py | 246 +++ .../model_server/server.py | 153 ++ .../model_server_client/trt_llm.py | 544 +++++++ .../llm-inference-server/requirements.txt | 7 + SECURITY.md | 24 + 25 files changed, 4975 insertions(+) create mode 100644 RetrievalAugmentedGeneration/LICENSE-Apache-2.0.txt create mode 100644 RetrievalAugmentedGeneration/LICENSE.md create mode 100644 RetrievalAugmentedGeneration/llm-inference-server/Dockerfile create mode 100644 RetrievalAugmentedGeneration/llm-inference-server/conversion_scripts/llama/build.py create mode 100644 RetrievalAugmentedGeneration/llm-inference-server/conversion_scripts/llama/weight.py create mode 120000 RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/gptnext create mode 100644 RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/ensemble/1/.gitkeep create mode 100755 RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/ensemble/config.pbtxt create mode 100755 RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/postprocessing/1/model.py create mode 100755 RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/postprocessing/config.pbtxt create mode 100644 RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/preprocessing/1/model.py create mode 100644 RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/preprocessing/config.pbtxt create mode 100644 RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/tensorrt_llm/1/.gitkeep create mode 100644 RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/tensorrt_llm/config.pbtxt.j2 create mode 100644 RetrievalAugmentedGeneration/llm-inference-server/model_server/__init__.py create mode 100644 RetrievalAugmentedGeneration/llm-inference-server/model_server/__main__.py create mode 100644 RetrievalAugmentedGeneration/llm-inference-server/model_server/conversion/__init__.py create mode 100644 RetrievalAugmentedGeneration/llm-inference-server/model_server/conversion/llama.py create mode 100644 RetrievalAugmentedGeneration/llm-inference-server/model_server/conversion/nemo.py create mode 100644 RetrievalAugmentedGeneration/llm-inference-server/model_server/errors.py create mode 100644 RetrievalAugmentedGeneration/llm-inference-server/model_server/model.py create mode 100644 
RetrievalAugmentedGeneration/llm-inference-server/model_server/server.py create mode 100644 RetrievalAugmentedGeneration/llm-inference-server/model_server_client/trt_llm.py create mode 100644 RetrievalAugmentedGeneration/llm-inference-server/requirements.txt create mode 100644 SECURITY.md diff --git a/RetrievalAugmentedGeneration/LICENSE-Apache-2.0.txt b/RetrievalAugmentedGeneration/LICENSE-Apache-2.0.txt new file mode 100644 index 00000000..08017da8 --- /dev/null +++ b/RetrievalAugmentedGeneration/LICENSE-Apache-2.0.txt @@ -0,0 +1,177 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ +END OF TERMS AND CONDITIONS diff --git a/RetrievalAugmentedGeneration/LICENSE.md b/RetrievalAugmentedGeneration/LICENSE.md new file mode 100644 index 00000000..89bdbd99 --- /dev/null +++ b/RetrievalAugmentedGeneration/LICENSE.md @@ -0,0 +1,14 @@ +SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: Apache-2.0 + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. \ No newline at end of file diff --git a/RetrievalAugmentedGeneration/llm-inference-server/Dockerfile b/RetrievalAugmentedGeneration/llm-inference-server/Dockerfile new file mode 100644 index 00000000..43fa8e79 --- /dev/null +++ b/RetrievalAugmentedGeneration/llm-inference-server/Dockerfile @@ -0,0 +1,22 @@ +ARG BASE_IMAGE_URL=nvcr.io/ea-bignlp/beta-inf-prerelease/infer +ARG BASE_IMAGE_TAG=23.10.v3 + +FROM ${BASE_IMAGE_URL}:${BASE_IMAGE_TAG} + +ENV LD_LIBRARY_PATH=/opt/tritonserver/backends/tensorrtllm:$LD_LIBRARY_PATH + +# install model-server automation +COPY conversion_scripts /opt/conversion_scripts +COPY ensemble_models /opt/ensemble_models +COPY model_server /opt/model_server +COPY model_server_client /opt/model_server_client +RUN --mount=type=bind,source=requirements.txt,target=/opt/requirements.txt \ + pip install --no-cache-dir -r /opt/requirements.txt + +# Create basic directories + +RUN mkdir /model && chmod 1777 /model && \ + mkdir -p /home/triton-server && chown 1000:1000 /home/triton-server && chmod 700 /home/triton-server + +WORKDIR /opt +ENTRYPOINT ["/usr/bin/python3", "-m", "model_server"] diff --git a/RetrievalAugmentedGeneration/llm-inference-server/conversion_scripts/llama/build.py b/RetrievalAugmentedGeneration/llm-inference-server/conversion_scripts/llama/build.py new file mode 100644 index 00000000..2a57d4cd --- /dev/null +++ b/RetrievalAugmentedGeneration/llm-inference-server/conversion_scripts/llama/build.py @@ -0,0 +1,776 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
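+# Example invocation (illustrative only; the paths below are placeholders and
+# the exact flags depend on the checkpoint and GPU count. Inside the inference
+# container this build is normally driven by the model_server automation
+# rather than run by hand):
+#
+#   python3 build.py \
+#       --model_dir /model/llama-2-13b-chat-hf \
+#       --dtype float16 \
+#       --use_gpt_attention_plugin float16 \
+#       --use_gemm_plugin float16 \
+#       --world_size 1 --tp_size 1 \
+#       --max_batch_size 8 --max_input_len 2048 --max_output_len 512 \
+#       --output_dir /model/trt_engines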
+import argparse +import json +import os +import time +from pathlib import Path + +import tensorrt as trt +import tensorrt_llm +import torch +import torch.multiprocessing as mp +from tensorrt_llm._utils import str_dtype_to_trt +from tensorrt_llm.builder import Builder +from tensorrt_llm.layers.attention import PositionEmbeddingType +from tensorrt_llm.logger import logger +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models import ( + fp8_quantize, + smooth_quantize, + weight_only_groupwise_quantize, + weight_only_quantize, +) +from tensorrt_llm.network import net_guard +from tensorrt_llm.plugin.plugin import ContextFMHAType +from tensorrt_llm.quantization import QuantMode +from transformers import LlamaConfig, LlamaForCausalLM +from weight import ( + get_scaling_factors, + load_from_awq_llama, + load_from_binary, + load_from_gptq_llama, + load_from_hf_llama, + load_from_meta_llama, +) + +from weight import parse_ft_config # isort:skip + +MODEL_NAME = "llama" + +# 2 routines: get_engine_name, serialize_engine +# are direct copy from gpt example, TODO: put in utils? + +import onnx +import tensorrt as trt +from onnx import TensorProto, helper + + +def trt_dtype_to_onnx(dtype): + if dtype == trt.float16: + return TensorProto.DataType.FLOAT16 + elif dtype == trt.float32: + return TensorProto.DataType.FLOAT + elif dtype == trt.int32: + return TensorProto.DataType.INT32 + else: + raise TypeError("%s is not supported" % dtype) + + +def to_onnx(network, path): + inputs = [] + for i in range(network.num_inputs): + network_input = network.get_input(i) + inputs.append( + helper.make_tensor_value_info( + network_input.name, + trt_dtype_to_onnx(network_input.dtype), + list(network_input.shape), + ) + ) + + outputs = [] + for i in range(network.num_outputs): + network_output = network.get_output(i) + outputs.append( + helper.make_tensor_value_info( + network_output.name, + trt_dtype_to_onnx(network_output.dtype), + list(network_output.shape), + ) + ) + + nodes = [] + for i in range(network.num_layers): + layer = network.get_layer(i) + layer_inputs = [] + for j in range(layer.num_inputs): + ipt = layer.get_input(j) + if ipt is not None: + layer_inputs.append(layer.get_input(j).name) + layer_outputs = [layer.get_output(j).name for j in range(layer.num_outputs)] + nodes.append( + helper.make_node( + str(layer.type), + name=layer.name, + inputs=layer_inputs, + outputs=layer_outputs, + domain="com.nvidia", + ) + ) + + onnx_model = helper.make_model( + helper.make_graph(nodes, "attention", inputs, outputs, initializer=None), + producer_name="NVIDIA", + ) + onnx.save(onnx_model, path) + + +def get_engine_name(model, dtype, tp_size, pp_size, rank): + if pp_size == 1: + return "{}_{}_tp{}_rank{}.engine".format(model, dtype, tp_size, rank) + return "{}_{}_tp{}_pp{}_rank{}.engine".format(model, dtype, tp_size, pp_size, rank) + + +def serialize_engine(engine, path): + logger.info(f"Serializing engine to {path}...") + tik = time.time() + with open(path, "wb") as f: + f.write(bytearray(engine)) + tok = time.time() + t = time.strftime("%H:%M:%S", time.gmtime(tok - tik)) + logger.info(f"Engine serialized. 
Total time: {t}") + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument("--world_size", type=int, default=1) + parser.add_argument("--tp_size", type=int, default=1) + parser.add_argument("--pp_size", type=int, default=1) + parser.add_argument("--model_dir", type=str, default=None) + parser.add_argument("--ft_model_dir", type=str, default=None) + parser.add_argument("--meta_ckpt_dir", type=str, default=None) + parser.add_argument("--quant_ckpt_path", type=str, default=None) + parser.add_argument( + "--dtype", + type=str, + default="float16", + choices=["float32", "bfloat16", "float16"], + ) + parser.add_argument( + "--timing_cache", + type=str, + default="model.cache", + help="The path of to read timing cache from, will be ignored if the file does not exist", + ) + parser.add_argument("--log_level", type=str, default="info") + parser.add_argument("--vocab_size", type=int, default=32000) + parser.add_argument("--n_layer", type=int, default=32) + parser.add_argument("--n_positions", type=int, default=2048) + parser.add_argument("--n_embd", type=int, default=4096) + parser.add_argument("--n_head", type=int, default=32) + parser.add_argument("--n_kv_head", type=int, default=None) + parser.add_argument("--multiple_of", type=int, default=256) + parser.add_argument("--ffn_dim_multiplier", type=float, default=1.0) + parser.add_argument("--inter_size", type=int, default=None) + parser.add_argument("--hidden_act", type=str, default="silu") + parser.add_argument("--rms_norm_eps", type=float, default=1e-06) + parser.add_argument("--max_batch_size", type=int, default=8) + parser.add_argument("--max_input_len", type=int, default=2048) + parser.add_argument("--max_output_len", type=int, default=512) + parser.add_argument("--max_beam_width", type=int, default=1) + parser.add_argument("--rotary_base", type=float, default=10000.0) + parser.add_argument("--rotary_scaling", nargs=2, type=str, default=None) + parser.add_argument( + "--use_gpt_attention_plugin", + nargs="?", + const="float16", + type=str, + default=False, + choices=["float16", "bfloat16", "float32"], + ) + parser.add_argument( + "--use_gemm_plugin", + nargs="?", + const="float16", + type=str, + default=False, + choices=["float16", "bfloat16", "float32"], + ) + parser.add_argument( + "--use_rmsnorm_plugin", + nargs="?", + const="float16", + type=str, + default=False, + choices=["float16", "float32", "bfloat16"], + ) + parser.add_argument("--parallel_build", default=False, action="store_true") + parser.add_argument("--enable_context_fmha", default=False, action="store_true") + parser.add_argument( + "--enable_context_fmha_fp32_acc", default=False, action="store_true" + ) + parser.add_argument("--visualize", default=False, action="store_true") + parser.add_argument("--enable_debug_output", default=False, action="store_true") + parser.add_argument("--gpus_per_node", type=int, default=8) + parser.add_argument("--builder_opt", type=int, default=None) + parser.add_argument( + "--output_dir", + type=str, + default="llama_outputs", + help="The path to save the serialized engine files, timing cache file and model configs", + ) + parser.add_argument("--remove_input_padding", default=False, action="store_true") + + # Arguments related to the quantization of the model. + parser.add_argument( + "--use_smooth_quant", + default=False, + action="store_true", + help="Use the SmoothQuant method to quantize activations and weights for the various GEMMs." 
+ "See --per_channel and --per_token for finer-grained quantization options.", + ) + parser.add_argument( + "--per_channel", + default=False, + action="store_true", + help="By default, we use a single static scaling factor for the GEMM's result. " + "per_channel instead uses a different static scaling factor for each channel. " + "The latter is usually more accurate, but a little slower.", + ) + parser.add_argument( + "--per_token", + default=False, + action="store_true", + help="By default, we use a single static scaling factor to scale activations in the int8 range. " + "per_token chooses at run time, and for each token, a custom scaling factor. " + "The latter is usually more accurate, but a little slower.", + ) + parser.add_argument( + "--per_group", + default=False, + action="store_true", + help="By default, we use a single static scaling factor to scale weights in the int4 range. " + "per_group chooses at run time, and for each group, a custom scaling factor. " + "The flag is built for GPTQ/AWQ quantization.", + ) + parser.add_argument( + "--group_size", + type=int, + default=128, + help="Group size used in GPTQ/AWQ quantization.", + ) + parser.add_argument( + "--int8_kv_cache", + default=False, + action="store_true", + help="By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV", + ) + parser.add_argument( + "--use_parallel_embedding", + action="store_true", + default=False, + help="By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled", + ) + parser.add_argument( + "--embedding_sharding_dim", + type=int, + default=1, # Meta does TP on hidden dim + choices=[0, 1], + help="By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). " + "To shard it along hidden dimension, set embedding_sharding_dim=1" + "Note: embedding sharing is only enabled when embedding_sharding_dim = 0", + ) + parser.add_argument( + "--enable_fp8", + default=False, + action="store_true", + help="Use FP8 Linear layer for Attention QKV/Dense and MLP.", + ) + parser.add_argument( + "--fp8_kv_cache", + default=False, + action="store_true", + help="By default, we use dtype for KV cache. fp8_kv_cache chooses int8 quantization for KV", + ) + parser.add_argument( + "--quantized_fp8_model_path", + type=str, + default=None, + help="Path of a quantized model checkpoint in .npz format", + ) + parser.add_argument( + "--use_weight_only", + default=False, + action="store_true", + help="Quantize weights for the various GEMMs to INT4/INT8." + "See --weight_only_precision to set the precision", + ) + parser.add_argument( + "--weight_only_precision", + const="int8", + type=str, + nargs="?", + default="int8", + choices=["int8", "int4", "int4_awq", "int4_gptq"], + help="Define the precision for the weights when using weight-only quantization." + "You must also use --use_weight_only for that argument to have an impact.", + ) + parser.add_argument( + "--use_inflight_batching", + action="store_true", + default=False, + help="Activates inflight batching mode of gptAttentionPlugin.", + ) + parser.add_argument( + "--paged_kv_cache", + action="store_true", + default=False, + help="By default we use contiguous KV cache. 
By setting this flag you enable paged KV cache", + ) + parser.add_argument( + "--tokens_per_block", + type=int, + default=64, + help="Number of tokens per block in paged KV cache", + ) + parser.add_argument( + "--max_num_tokens", + type=int, + default=None, + help="Define the max number of tokens supported by the engine", + ) + parser.add_argument( + "--strongly_typed", + default=False, + action="store_true", + help="This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.", + ) + parser.add_argument( + "--use_custom_all_reduce", + action="store_true", + help="Activates latency-optimized algorithm for all-reduce instead of NCCL.", + ) + + args = parser.parse_args() + tensorrt_llm.logger.set_level(args.log_level) + + assert not ( + args.use_smooth_quant and args.use_weight_only + ), "You cannot enable both SmoothQuant and INT8 weight-only together." + + if not args.remove_input_padding: + if args.use_gpt_attention_plugin: + logger.warning( + f"It is recommended to specify --remove_input_padding when using GPT attention plugin" + ) + + if args.use_inflight_batching: + if not args.use_gpt_attention_plugin: + args.use_gpt_attention_plugin = "float16" + logger.info( + f"Using GPT attention plugin for inflight batching mode. Setting to default '{args.use_gpt_attention_plugin}'" + ) + if not args.remove_input_padding: + args.remove_input_padding = True + logger.info("Using remove input padding for inflight batching mode.") + if not args.paged_kv_cache: + args.paged_kv_cache = True + logger.info("Using paged KV cache for inflight batching mode.") + + if args.use_smooth_quant: + args.quant_mode = QuantMode.use_smooth_quant(args.per_token, args.per_channel) + elif args.use_weight_only: + if args.per_group: + args.quant_mode = QuantMode.from_description( + quantize_weights=True, + quantize_activations=False, + per_token=False, + per_channel=False, + per_group=True, + use_int4_weights=True, + ) + else: + args.quant_mode = QuantMode.use_weight_only( + args.weight_only_precision == "int4" + ) + else: + args.quant_mode = QuantMode(0) + + if args.int8_kv_cache: + args.quant_mode = args.quant_mode.set_int8_kv_cache() + elif args.fp8_kv_cache: + args.quant_mode = args.quant_mode.set_fp8_kv_cache() + if args.enable_fp8: + args.quant_mode = args.quant_mode.set_fp8_qdq() + + if args.rotary_scaling is not None: + rotary_scaling = { + "type": args.rotary_scaling[0], + "factor": float(args.rotary_scaling[1]), + } + assert rotary_scaling["type"] in ["linear", "dynamic"] + assert rotary_scaling["factor"] > 1.0 + args.rotary_scaling = rotary_scaling + if rotary_scaling["type"] == "dynamic": + assert not args.remove_input_padding, "TODO: Not supported yet" + + # Since gpt_attenttion_plugin is the only way to apply RoPE now, + # force use the plugin for now with the correct data type. 
+ args.use_gpt_attention_plugin = args.dtype + if args.model_dir is not None: + hf_config = LlamaConfig.from_pretrained(args.model_dir) + args.inter_size = ( + hf_config.intermediate_size + ) # override the inter_size for LLaMA + args.n_embd = hf_config.hidden_size + args.n_head = hf_config.num_attention_heads + if hasattr(hf_config, "num_key_value_heads"): + args.n_kv_head = hf_config.num_key_value_heads + args.n_layer = hf_config.num_hidden_layers + args.n_positions = hf_config.max_position_embeddings + args.vocab_size = hf_config.vocab_size + args.hidden_act = hf_config.hidden_act + args.rms_norm_eps = hf_config.rms_norm_eps + elif args.meta_ckpt_dir is not None: + with open(Path(args.meta_ckpt_dir, "params.json")) as fp: + meta_config: dict = json.load(fp) + args.n_embd = meta_config["dim"] + args.n_head = meta_config["n_heads"] + args.n_layer = meta_config["n_layers"] + args.n_kv_head = meta_config.get("n_kv_heads", args.n_head) + args.multiple_of = meta_config["multiple_of"] + args.ffn_dim_multiplier = meta_config.get("ffn_dim_multiplier", 1) + n_embd = int(4 * args.n_embd * 2 / 3) + args.inter_size = args.multiple_of * ( + (int(n_embd * args.ffn_dim_multiplier) + args.multiple_of - 1) + // args.multiple_of + ) + args.rms_norm_eps = meta_config["norm_eps"] + elif args.ft_model_dir is not None: + ( + n_embd, + n_head, + n_layer, + n_positions, + vocab_size, + hidden_act, + inter_size, + n_kv_head, + ) = parse_ft_config(Path(args.ft_model_dir) / "config.ini") + args.inter_size = inter_size # override the inter_size for LLaMA + args.n_kv_head = n_kv_head + args.n_embd = n_embd + args.n_head = n_head + args.n_layer = n_layer + args.n_positions = n_positions + args.vocab_size = vocab_size + args.hidden_act = hidden_act + args.rms_norm_eps = 1e-06 + logger.warning("Set rms_norm_eps to 1e-06 directly.") + assert args.use_gpt_attention_plugin, "LLaMa must use gpt attention plugin" + if args.n_kv_head is None: + args.n_kv_head = args.n_head + elif args.n_kv_head != args.n_head: + assert ( + args.n_head % args.n_kv_head + ) == 0, "MQA/GQA requires the number of heads to be divisible by the number of K/V heads." + assert (args.n_kv_head % args.tp_size) == 0 or ( + args.tp_size % args.n_kv_head + ) == 0, ( + "MQA/GQA requires either the number of K/V heads to be divisible by the tensor parallelism size OR " + "the tensor parallelism size to be divisible by the number of K/V heads." + ) + + if args.dtype == "bfloat16": + assert args.use_gemm_plugin, "Please use gemm plugin when dtype is bfloat16" + + assert args.pp_size * args.tp_size == args.world_size + + if args.max_num_tokens is not None: + assert args.enable_context_fmha + + if args.inter_size is None: + # this should not be need when loading a real model + # but it is helpful when creating a dummy model without loading any real weights + n_embd = int(4 * args.n_embd * 2 / 3) + args.inter_size = args.multiple_of * ( + (int(n_embd * args.ffn_dim_multiplier) + args.multiple_of - 1) + // args.multiple_of + ) + logger.info(f"Setting inter_size to {args.inter_size}.") + + return args + + +def build_rank_engine( + builder: Builder, + builder_config: tensorrt_llm.builder.BuilderConfig, + engine_name, + rank, + args, +): + """ + @brief: Build the engine on the given rank. + @param rank: The rank to build the engine. + @param args: The cmd line arguments. + @return: The built engine. 
+ """ + dtype = str_dtype_to_trt(args.dtype) + mapping = Mapping( + world_size=args.world_size, + rank=rank, + tp_size=args.tp_size, + pp_size=args.pp_size, + ) + + assert ( + args.n_layer % args.pp_size == 0 + ), f"num_layers {args.n_layer} must be a multiple of pipeline parallelism size {args.pp_size}" + + # Initialize Module + tensorrt_llm_llama = tensorrt_llm.models.LLaMAForCausalLM( + num_layers=args.n_layer, + num_heads=args.n_head, + num_kv_heads=args.n_kv_head, + hidden_size=args.n_embd, + vocab_size=args.vocab_size, + hidden_act=args.hidden_act, + max_position_embeddings=args.n_positions, + dtype=dtype, + mlp_hidden_size=args.inter_size, + position_embedding_type=PositionEmbeddingType.rope_gpt_neox, + mapping=mapping, + rotary_base=args.rotary_base, + rotary_scaling=args.rotary_scaling, + use_parallel_embedding=args.use_parallel_embedding, + embedding_sharding_dim=args.embedding_sharding_dim, + quant_mode=args.quant_mode, + rms_norm_eps=args.rms_norm_eps, + ) + if args.use_smooth_quant: + tensorrt_llm_llama = smooth_quantize(tensorrt_llm_llama, args.quant_mode) + elif args.use_weight_only: + if args.weight_only_precision == "int8": + tensorrt_llm_llama = weight_only_quantize( + tensorrt_llm_llama, args.quant_mode + ) + elif args.weight_only_precision == "int4": + tensorrt_llm_llama = weight_only_quantize( + tensorrt_llm_llama, args.quant_mode + ) + elif args.weight_only_precision == "int4_awq": + tensorrt_llm_llama = weight_only_groupwise_quantize( + model=tensorrt_llm_llama, + quant_mode=args.quant_mode, + group_size=args.group_size, + zero=False, + pre_quant_scale=True, + exclude_modules=[], + ) + elif args.weight_only_precision == "int4_gptq": + tensorrt_llm_llama = weight_only_groupwise_quantize( + model=tensorrt_llm_llama, + quant_mode=args.quant_mode, + group_size=args.group_size, + zero=True, + pre_quant_scale=False, + ) + elif args.enable_fp8 or args.fp8_kv_cache: + logger.info(f"Loading scaling factors from " f"{args.quantized_fp8_model_path}") + quant_scales = get_scaling_factors( + args.quantized_fp8_model_path, + num_layers=args.n_layer, + quant_mode=args.quant_mode, + ) + tensorrt_llm_llama = fp8_quantize( + tensorrt_llm_llama, quant_mode=args.quant_mode, quant_scales=quant_scales + ) + if args.per_group: + load_func = ( + load_from_awq_llama + if args.weight_only_precision == "int4_awq" + else load_from_gptq_llama + ) + load_func( + tensorrt_llm_llama=tensorrt_llm_llama, + quant_ckpt_path=args.quant_ckpt_path, + mapping=mapping, + dtype=args.dtype, + ) + elif args.meta_ckpt_dir is not None: + load_from_meta_llama( + tensorrt_llm_llama, args.meta_ckpt_dir, mapping, args.dtype + ) + elif args.model_dir is not None: + logger.info(f"Loading HF LLaMA ... from {args.model_dir}") + tik = time.time() + hf_llama = LlamaForCausalLM.from_pretrained( + args.model_dir, + device_map={"model": "cpu", "lm_head": "cpu"}, # Load to CPU memory + torch_dtype="auto", + ) + tok = time.time() + t = time.strftime("%H:%M:%S", time.gmtime(tok - tik)) + logger.info(f"HF LLaMA loaded. 
Total time: {t}") + load_from_hf_llama( + tensorrt_llm_llama, hf_llama, mapping=mapping, dtype=args.dtype + ) + del hf_llama + elif args.ft_model_dir is not None: + load_from_binary( + tensorrt_llm_llama, + args.ft_model_dir, + mapping, + fp16=(args.dtype == "float16"), + multi_query_mode=(args.n_kv_head != args.n_head), + ) + + # Module -> Network + network = builder.create_network() + network.trt_network.name = engine_name + if args.use_gpt_attention_plugin: + network.plugin_config.set_gpt_attention_plugin( + dtype=args.use_gpt_attention_plugin + ) + if args.use_gemm_plugin: + network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin) + if args.use_rmsnorm_plugin: + network.plugin_config.set_rmsnorm_plugin(dtype=args.use_rmsnorm_plugin) + + # Quantization plugins. + if args.use_smooth_quant: + network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype) + network.plugin_config.set_rmsnorm_quantization_plugin(dtype=args.dtype) + network.plugin_config.set_quantize_tensor_plugin() + network.plugin_config.set_quantize_per_token_plugin() + assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc) + if args.enable_context_fmha: + network.plugin_config.set_context_fmha(ContextFMHAType.enabled) + if args.enable_context_fmha_fp32_acc: + network.plugin_config.set_context_fmha(ContextFMHAType.enabled_with_fp32_acc) + if args.use_weight_only: + if args.per_group: + network.plugin_config.set_weight_only_groupwise_quant_matmul_plugin( + dtype="float16" + ) + else: + network.plugin_config.set_weight_only_quant_matmul_plugin(dtype="float16") + if args.world_size > 1: + network.plugin_config.set_nccl_plugin(args.dtype, args.use_custom_all_reduce) + if args.remove_input_padding: + network.plugin_config.enable_remove_input_padding() + if args.paged_kv_cache: + network.plugin_config.enable_paged_kv_cache(args.tokens_per_block) + + with net_guard(network): + # Prepare + network.set_named_parameters(tensorrt_llm_llama.named_parameters()) + + # Forward + inputs = tensorrt_llm_llama.prepare_inputs( + args.max_batch_size, + args.max_input_len, + args.max_output_len, + True, + args.max_beam_width, + args.max_num_tokens, + ) + tensorrt_llm_llama(*inputs) + if args.enable_debug_output: + # mark intermediate nodes' outputs + for k, v in tensorrt_llm_llama.named_network_outputs(): + v = v.trt_tensor + v.name = k + network.trt_network.mark_output(v) + v.dtype = dtype + if args.visualize: + model_path = os.path.join(args.output_dir, "test.onnx") + to_onnx(network.trt_network, model_path) + + tensorrt_llm.graph_rewriting.optimize(network) + + engine = None + + # Network -> Engine + engine = builder.build_engine(network, builder_config) + if rank == 0: + config_path = os.path.join(args.output_dir, "config.json") + builder.save_config(builder_config, config_path) + return engine + + +def build(rank, args): + torch.cuda.set_device(rank % args.gpus_per_node) + logger.set_level(args.log_level) + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + + # when doing serializing build, all ranks share one engine + builder = Builder() + + cache = None + for cur_rank in range(args.world_size): + # skip other ranks if parallel_build is enabled + if args.parallel_build and cur_rank != rank: + continue + # NOTE: when only int8 kv cache is used together with paged kv cache no int8 tensors are exposed to TRT + int8_trt_flag = args.quant_mode.has_act_and_weight_quant() or ( + not args.paged_kv_cache and args.quant_mode.has_int8_kv_cache() + ) + builder_config = 
builder.create_builder_config( + name=MODEL_NAME, + precision=args.dtype, + timing_cache=args.timing_cache if cache is None else cache, + tensor_parallel=args.tp_size, + pipeline_parallel=args.pp_size, + parallel_build=args.parallel_build, + num_layers=args.n_layer, + num_heads=args.n_head, + num_kv_heads=args.n_kv_head, + hidden_size=args.n_embd, + vocab_size=args.vocab_size, + hidden_act=args.hidden_act, + max_position_embeddings=args.n_positions, + max_batch_size=args.max_batch_size, + max_input_len=args.max_input_len, + max_output_len=args.max_output_len, + max_num_tokens=args.max_num_tokens, + int8=int8_trt_flag, + fp8=args.quant_mode.has_fp8_qdq(), + quant_mode=args.quant_mode, + strongly_typed=args.strongly_typed, + opt_level=args.builder_opt, + ) + engine_name = get_engine_name( + MODEL_NAME, args.dtype, args.tp_size, args.pp_size, cur_rank + ) + engine = build_rank_engine(builder, builder_config, engine_name, cur_rank, args) + assert engine is not None, f"Failed to build engine for rank {cur_rank}" + + if cur_rank == 0: + # Use in-memory timing cache for multiple builder passes. + if not args.parallel_build: + cache = builder_config.trt_builder_config.get_timing_cache() + + serialize_engine(engine, os.path.join(args.output_dir, engine_name)) + + if rank == 0: + ok = builder.save_timing_cache( + builder_config, os.path.join(args.output_dir, "model.cache") + ) + assert ok, "Failed to save timing cache." + + +if __name__ == "__main__": + args = parse_arguments() + tik = time.time() + if ( + args.parallel_build + and args.world_size > 1 + and torch.cuda.device_count() >= args.world_size + ): + logger.warning( + f"Parallelly build TensorRT engines. Please make sure that all of the {args.world_size} GPUs are totally free." + ) + mp.spawn(build, nprocs=args.world_size, args=(args,)) + else: + args.parallel_build = False + logger.info("Serially build TensorRT engines.") + build(0, args) + + tok = time.time() + t = time.strftime("%H:%M:%S", time.gmtime(tok - tik)) + logger.info(f"Total time of building all {args.world_size} engines: {t}") diff --git a/RetrievalAugmentedGeneration/llm-inference-server/conversion_scripts/llama/weight.py b/RetrievalAugmentedGeneration/llm-inference-server/conversion_scripts/llama/weight.py new file mode 100644 index 00000000..692ae67f --- /dev/null +++ b/RetrievalAugmentedGeneration/llm-inference-server/conversion_scripts/llama/weight.py @@ -0,0 +1,1446 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
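+# The load_from_* functions in this module are called by
+# conversion_scripts/llama/build.py after it constructs the TensorRT-LLM
+# LLaMAForCausalLM module; they copy checkpoint weights into that module's
+# named parameters. A minimal sketch of the Hugging Face path, mirroring the
+# call in build.py (model_dir is a placeholder):
+#
+#   hf_llama = LlamaForCausalLM.from_pretrained(model_dir, torch_dtype="auto")
+#   load_from_hf_llama(tensorrt_llm_llama, hf_llama, mapping=mapping,
+#                      dtype="float16")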
+import configparser +import time +from operator import attrgetter +from pathlib import Path +from typing import Dict, List, Optional, Union + +import numpy as np +import tensorrt_llm +import tensorrt_llm.logger as logger +import torch +from safetensors import safe_open +from tensorrt_llm._utils import str_dtype_to_torch, torch_to_numpy +from tensorrt_llm.mapping import Mapping +from tensorrt_llm.models import LLaMAForCausalLM +from tensorrt_llm.models.quantized.quant import get_dummy_quant_scales +from tensorrt_llm.quantization import QuantMode + + +def get_scaling_factors( + model_path: Union[str, Path], + num_layers: int, + quant_mode: Optional[QuantMode] = None, +) -> Optional[Dict[str, List[int]]]: + """Get the scaling factors for LLaMA model + + Returns a dictionary of scaling factors for the selected layers of the + LLaMA model. + + Args: + model_path (str): Path to the quantized LLaMA model + layers (list): List of layers to get the scaling factors for. If None, + all layers are selected. + + Returns: + dict: Dictionary of scaling factors for the selected layers of the + LLaMA model. + + example: + + { + 'qkv_act': qkv_act_scale, + 'qkv_weights': qkv_weights_scale, + 'qkv_output' : qkv_outputs_scale, + 'dense_act': dense_act_scale, + 'dense_weights': dense_weights_scale, + 'fc_act': fc_act_scale, + 'fc_weights': fc_weights_scale, + 'gate_act': gate_act_scale, + 'gate_weights': gate_weights_scale, + 'proj_act': proj_act_scale, + 'proj_weights': proj_weights_scale, + } + """ + + if model_path is None: + logger.warning( + f"--quantized_fp8_model_path not specified. " + f"Initialize quantization scales automatically." + ) + return get_dummy_quant_scales(num_layers) + weight_dict = np.load(model_path) + + # yapf: disable + scaling_factor = { + 'qkv_act': [], + 'qkv_weights': [], + 'qkv_output': [], + 'dense_act': [], + 'dense_weights': [], + 'fc_act': [], + 'fc_weights': [], + 'gate_act': [], + 'gate_weights': [], + 'proj_act': [], + 'proj_weights': [], + } + + for layer in range(num_layers): + scaling_factor['qkv_act'].append(max( + weight_dict[f'_np:layers:{layer}:attention:qkv:q:activation_scaling_factor'].item(), + weight_dict[f'_np:layers:{layer}:attention:qkv:k:activation_scaling_factor'].item(), + weight_dict[f'_np:layers:{layer}:attention:qkv:v:activation_scaling_factor'].item() + )) + scaling_factor['qkv_weights'].append(max( + weight_dict[f'_np:layers:{layer}:attention:qkv:q:weights_scaling_factor'].item(), + weight_dict[f'_np:layers:{layer}:attention:qkv:k:weights_scaling_factor'].item(), + weight_dict[f'_np:layers:{layer}:attention:qkv:v:weights_scaling_factor'].item() + )) + if quant_mode is not None and quant_mode.has_fp8_kv_cache(): + # Not calibrarting KV cache. 
+ scaling_factor['qkv_output'].append(1.0) + scaling_factor['dense_act'].append(weight_dict[f'_np:layers:{layer}:attention:dense:activation_scaling_factor'].item()) + scaling_factor['dense_weights'].append(weight_dict[f'_np:layers:{layer}:attention:dense:weights_scaling_factor'].item()) + scaling_factor['fc_act'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:activation_scaling_factor'].item()) + scaling_factor['fc_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:weights_scaling_factor'].item()) + scaling_factor['gate_act'].append(weight_dict[f'_np:layers:{layer}:mlp:gate:activation_scaling_factor'].item()) + scaling_factor['gate_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:gate:weights_scaling_factor'].item()) + scaling_factor['proj_act'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:activation_scaling_factor'].item()) + scaling_factor['proj_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:weights_scaling_factor'].item()) + # yapf: enable + for k, v in scaling_factor.items(): + assert ( + len(v) == num_layers + ), f"Expect scaling factor {k} of length {num_layers}, got {len(v)}" + + return scaling_factor + + +def gen_suffix(rank, use_smooth_quant, quant_per_channel): + suffix = f"{rank}.bin" + if use_smooth_quant: + sq_prefix = "int8." + if quant_per_channel: + sq_prefix += "col." + suffix = sq_prefix + suffix + return suffix + + +def extract_layer_idx(name): + ss = name.split(".") + for s in ss: + if s.isdigit(): + return s + return None + + +def split(v, tp_size, idx, dim=0): + if tp_size == 1: + return v + if len(v.shape) == 1: + return np.ascontiguousarray(np.split(v, tp_size)[idx]) + else: + return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx]) + + +def dup_kv_weight(v, num_head, tp_size): + assert tp_size % num_head == 0 + reps = tp_size // num_head + head_size = v.shape[0] // num_head + v = v.reshape(num_head, head_size, -1)[:, None, :, :].expand( + num_head, reps, head_size, v.shape[1] + ) + return v.reshape(num_head * reps * head_size, -1).clone() + + +def parse_ft_config(ini_file): + gpt_config = configparser.ConfigParser() + gpt_config.read(ini_file) + + n_embd = gpt_config.getint("llama", "hidden_size") + n_head = gpt_config.getint("llama", "num_attention_heads") + n_layer = gpt_config.getint("llama", "num_hidden_layers") + n_positions = gpt_config.getint("llama", "max_position_embeddings") + vocab_size = gpt_config.getint("llama", "vocab_size") + hidden_act = gpt_config.get("llama", "hidden_act") + inter_size = gpt_config.getint("llama", "intermediate_size", fallback=None) + n_kv_head = gpt_config.getint("llama", "num_key_value_heads", fallback=None) + + if inter_size is None: + inter_size = 4 * n_embd + + return ( + n_embd, + n_head, + n_layer, + n_positions, + vocab_size, + hidden_act, + inter_size, + n_kv_head, + ) + + +def load_from_hf_llama( + tensorrt_llm_llama: tensorrt_llm.models.LLaMAForCausalLM, + hf_llama, + mapping=Mapping(), + dtype="float32", +): + tensorrt_llm.logger.info("Loading weights from HF LLaMA...") + tik = time.time() + + quant_mode = getattr(tensorrt_llm_llama, "quant_mode", QuantMode(0)) + if quant_mode.is_int8_weight_only(): + plugin_weight_only_quant_type = torch.int8 + elif quant_mode.is_int4_weight_only(): + plugin_weight_only_quant_type = torch.quint4x2 + use_weight_only = quant_mode.is_weight_only() + num_kv_heads = tensorrt_llm_llama.num_kv_heads + mha_mode = num_kv_heads == tensorrt_llm_llama.num_heads + + model_params = dict(hf_llama.named_parameters()) + for l in 
range(hf_llama.config.num_hidden_layers): + prefix = f"model.layers.{l}.self_attn." + q_weight = model_params[prefix + "q_proj.weight"] + k_weight = model_params[prefix + "k_proj.weight"] + v_weight = model_params[prefix + "v_proj.weight"] + if not mha_mode: + head_size = tensorrt_llm_llama.hidden_size // tensorrt_llm_llama.num_heads + if num_kv_heads < mapping.tp_size: + # duplicate the KV heads up to tensor_parallel + k_weight = dup_kv_weight(k_weight, num_kv_heads, mapping.tp_size) + v_weight = dup_kv_weight(v_weight, num_kv_heads, mapping.tp_size) + assert (k_weight.shape[0] % (mapping.tp_size * head_size)) == 0 + assert (v_weight.shape[0] % (mapping.tp_size * head_size)) == 0 + qkv_weight = [q_weight, k_weight, v_weight] + else: + qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0) + + model_params[prefix + "qkv_proj.weight"] = qkv_weight + + torch_dtype = str_dtype_to_torch(dtype) + layers_per_pipeline_stage = hf_llama.config.num_hidden_layers // mapping.pp_size + layers_range = list( + range( + mapping.pp_rank * layers_per_pipeline_stage, + (mapping.pp_rank + 1) * layers_per_pipeline_stage, + 1, + ) + ) + for k, v in model_params.items(): + if isinstance(v, list): + v = [torch_to_numpy(vv.to(torch_dtype).detach().cpu()) for vv in v] + else: + v = torch_to_numpy(v.to(torch_dtype).detach().cpu()) + if "model.embed_tokens.weight" in k: + if tensorrt_llm_llama.use_parallel_embedding: + v = split( + v, + mapping.tp_size, + mapping.tp_rank, + tensorrt_llm_llama.embedding_sharding_dim, + ) + if mapping.is_first_pp_rank(): + tensorrt_llm_llama.vocab_embedding.weight.value = v + elif "model.norm.weight" in k: + if mapping.is_last_pp_rank(): + tensorrt_llm_llama.ln_f.weight.value = v + elif "lm_head.weight" in k: + if mapping.is_last_pp_rank(): + tensorrt_llm_llama.lm_head.weight.value = np.ascontiguousarray( + split(v, mapping.tp_size, mapping.tp_rank) + ) + else: + layer_idx = extract_layer_idx(k) + if layer_idx is None or int(layer_idx) not in layers_range: + continue + idx = int(layer_idx) - mapping.pp_rank * layers_per_pipeline_stage + if idx >= tensorrt_llm_llama.num_layers: + continue + if "input_layernorm.weight" in k: + tensorrt_llm_llama.layers[idx].input_layernorm.weight.value = v + elif "post_attention_layernorm.weight" in k: + dst = tensorrt_llm_llama.layers[idx].post_layernorm.weight + dst.value = v + elif "self_attn.qkv_proj.weight" in k: + dst = tensorrt_llm_llama.layers[idx].attention.qkv.weight + if not mha_mode: + assert isinstance(v, list) and len(v) == 3 + wq = split(v[0], mapping.tp_size, mapping.tp_rank) + wk = split(v[1], mapping.tp_size, mapping.tp_rank) + wv = split(v[2], mapping.tp_size, mapping.tp_rank) + split_v = np.concatenate((wq, wk, wv)) + else: + q_emb = v.shape[0] // 3 + model_emb = v.shape[1] + v = v.reshape(3, q_emb, model_emb) + split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=1) + split_v = split_v.reshape(3 * (q_emb // mapping.tp_size), model_emb) + if use_weight_only: + v = np.ascontiguousarray(split_v.transpose()) + ( + processed_torch_weights, + torch_weight_scales, + ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix( + torch.tensor(v), plugin_weight_only_quant_type + ) + # workaround for trt not supporting int8 inputs in plugins currently + dst.value = processed_torch_weights.view( + dtype=torch.float32 + ).numpy() + scales = tensorrt_llm_llama.layers[ + idx + ].attention.qkv.per_channel_scale + scales.value = torch_weight_scales.numpy() + else: + dst.value = np.ascontiguousarray(split_v) + elif 
"self_attn.o_proj.weight" in k: + dst = tensorrt_llm_llama.layers[idx].attention.dense.weight + split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=1) + if use_weight_only: + v = np.ascontiguousarray(split_v.transpose()) + ( + processed_torch_weights, + torch_weight_scales, + ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix( + torch.tensor(v), plugin_weight_only_quant_type + ) + # workaround for trt not supporting int8 inputs in plugins currently + dst.value = processed_torch_weights.view( + dtype=torch.float32 + ).numpy() + scales = tensorrt_llm_llama.layers[ + idx + ].attention.dense.per_channel_scale + scales.value = torch_weight_scales.numpy() + else: + dst.value = np.ascontiguousarray(split_v) + elif "mlp.up_proj.weight" in k: + dst = tensorrt_llm_llama.layers[idx].mlp.gate.weight + split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=0) + if use_weight_only: + v = np.ascontiguousarray(split_v.transpose()) + ( + processed_torch_weights, + torch_weight_scales, + ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix( + torch.tensor(v), plugin_weight_only_quant_type + ) + # workaround for trt not supporting int8 inputs in plugins currently + dst.value = processed_torch_weights.view( + dtype=torch.float32 + ).numpy() + scales = tensorrt_llm_llama.layers[idx].mlp.gate.per_channel_scale + scales.value = torch_weight_scales.numpy() + else: + dst.value = np.ascontiguousarray(split_v) + elif "mlp.down_proj.weight" in k: + dst = tensorrt_llm_llama.layers[idx].mlp.proj.weight + split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=1) + if use_weight_only: + v = np.ascontiguousarray(split_v.transpose()) + ( + processed_torch_weights, + torch_weight_scales, + ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix( + torch.tensor(v), plugin_weight_only_quant_type + ) + # workaround for trt not supporting int8 inputs in plugins currently + dst.value = processed_torch_weights.view( + dtype=torch.float32 + ).numpy() + scales = tensorrt_llm_llama.layers[idx].mlp.proj.per_channel_scale + scales.value = torch_weight_scales.numpy() + else: + dst.value = np.ascontiguousarray(split_v) + elif "mlp.gate_proj.weight" in k: + dst = tensorrt_llm_llama.layers[idx].mlp.fc.weight + split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=0) + if use_weight_only: + v = np.ascontiguousarray(split_v.transpose()) + ( + processed_torch_weights, + torch_weight_scales, + ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix( + torch.tensor(v), plugin_weight_only_quant_type + ) + # workaround for trt not supporting int8 inputs in plugins currently + dst.value = processed_torch_weights.view( + dtype=torch.float32 + ).numpy() + scales = tensorrt_llm_llama.layers[idx].mlp.fc.per_channel_scale + scales.value = torch_weight_scales.numpy() + else: + dst.value = np.ascontiguousarray(split_v) + + tok = time.time() + t = time.strftime("%H:%M:%S", time.gmtime(tok - tik)) + tensorrt_llm.logger.info(f"Weights loaded. 
Total time: {t}") + return + + +def load_from_meta_llama( + tensorrt_llm_llama: tensorrt_llm.models.LLaMAForCausalLM, + meta_ckpt_dir, + mapping=Mapping(), + dtype="float32", +): + torch_dtype = str_dtype_to_torch(dtype) + + def gather_ckpts(ckpts): + gathered = {} + for k in ckpts[0]: + d = 0 + if any([n in k for n in ["wo", "w2", "tok"]]): + d = 1 + if "norm" in k or "rope" in k: # no TP + gathered[k] = ckpts[0][k].clone() + else: + gathered[k] = torch.cat([pt[k] for pt in ckpts], dim=d).clone() + return gathered + + def split_ckpt(ckpt, ranks_per_ckpt, ckpt_rank): + split_ckpt = {} + for k in ckpt: + d = 0 + if any([n in k for n in ["wo", "w2", "tok"]]): + d = 1 + if "norm" in k or "rope" in k: # no TP + split_ckpt[k] = ckpt[k].clone() + elif tensorrt_llm_llama.num_kv_heads < mapping.tp_size and any( + [n in k for n in ["wk", "wv"]] + ): + assert mapping.tp_size % tensorrt_llm_llama.num_kv_heads == 0 + # special case: we need to duplicate KV head + tmp = dup_kv_weight( + ckpt[k], tensorrt_llm_llama.num_kv_heads, mapping.tp_size + ) + split_ckpt[k] = torch.split(tmp, tmp.shape[d] // ranks_per_ckpt, dim=d)[ + ckpt_rank + ].clone() + else: + split_ckpt[k] = torch.split( + ckpt[k], ckpt[k].shape[d] // ranks_per_ckpt, dim=d + )[ckpt_rank].clone() + return split_ckpt + + def get_current_weights(num_ckpts): + if num_ckpts > mapping.tp_size: + # combine ckpts + assert (num_ckpts % mapping.tp_size) == 0 + nf = num_ckpts // mapping.tp_size + fs = nf * mapping.tp_rank + file_ids = list(range(fs, fs + nf)) + ckpts = [] + for f in file_ids: + ckpt = torch.load( + Path(meta_ckpt_dir, f"consolidated.{f:02d}.pth"), map_location="cpu" + ) + ckpts.append(ckpt) + return gather_ckpts(ckpts) + elif num_ckpts < mapping.tp_size: + # split ckpt + assert (mapping.tp_size % num_ckpts) == 0 + ranks_per_ckpt = mapping.tp_size // num_ckpts + ckpt_fid = mapping.tp_rank // ranks_per_ckpt + ckpt_rank = mapping.tp_rank % ranks_per_ckpt + nH_per_ckpt = tensorrt_llm_llama.num_heads // num_ckpts + assert (nH_per_ckpt % ranks_per_ckpt) == 0 + ckpt = torch.load( + Path(meta_ckpt_dir, f"consolidated.{ckpt_fid:02d}.pth"), + map_location="cpu", + ) + return split_ckpt(ckpt, ranks_per_ckpt, ckpt_rank) + + # num_ckpts == tensor_parallel, 1:1 mapping from files to TP + return torch.load( + Path(meta_ckpt_dir, f"consolidated.{mapping.tp_rank:02d}.pth"), + map_location="cpu", + ) + + def permute(w, nH, d, dH): + # due to MQA's wk, nH*dH != d could be true + return w.view(nH, dH // 2, 2, d).transpose(1, 2).reshape(nH * dH, d) + + if not hasattr(load_from_meta_llama, "saved_embed"): + load_from_meta_llama.saved_embed = None + + def gather_embedding(cur_embed, name: str, num_ckpts): + if mapping.tp_size == 1: + # even if num_ckpts > 1, get_current_weights will already have it gathered + return cur_embed + if load_from_meta_llama.saved_embed is None: + embeds = [None] * num_ckpts + for i in range(num_ckpts): + ckpt = torch.load( + Path(meta_ckpt_dir, f"consolidated.{i:02d}.pth"), map_location="cpu" + ) + embeds[i] = ckpt[name] + embed = torch.cat(embeds, dim=1).to(torch_dtype) + load_from_meta_llama.saved_embed = torch_to_numpy( + embed + ) # cache the embedding, not needed if no refit + return load_from_meta_llama.saved_embed + + tensorrt_llm.logger.info("Loading weights from Meta LLaMA checkpoints ...") + tik = time.time() + + quant_mode = getattr(tensorrt_llm_llama, "quant_mode", QuantMode(0)) + if quant_mode.is_int8_weight_only(): + torch.int8 + elif quant_mode.is_int4_weight_only(): + torch.quint4x2 + quant_mode.is_weight_only() + 
num_kv_heads = tensorrt_llm_llama.num_kv_heads + mha_mode = num_kv_heads == tensorrt_llm_llama.num_heads + + ckpts = list(Path(meta_ckpt_dir).glob("consolidated.*.pth")) + num_ckpts = len(ckpts) + # llama/llama2 doesn't have MQA. So, simplifying loader logic by not worrying about it. + assert ( + num_kv_heads > 1 or num_kv_heads >= num_ckpts + ), f"We don't know how the {num_kv_heads} KV heads are distributed among {num_ckpts} checkpoints." + + head_size = tensorrt_llm_llama.hidden_size // tensorrt_llm_llama.num_heads + ckpt = get_current_weights(num_ckpts) + layers_range = list( + range( + mapping.pp_rank * tensorrt_llm_llama.num_layers, + (mapping.pp_rank + 1) * tensorrt_llm_llama.num_layers, + 1, + ) + ) + + for l in layers_range: + prefix = f"layers.{l}.attention." + q_weight = permute( + ckpt[prefix + "wq.weight"].clone(), + nH=(tensorrt_llm_llama.num_heads // mapping.tp_size), + d=tensorrt_llm_llama.hidden_size, + dH=head_size, + ) + if num_kv_heads < mapping.tp_size and num_ckpts >= mapping.tp_size: + assert mapping.tp_size % num_kv_heads == 0 + assert False, "Not supported yet" + k_weight = permute( + ckpt[prefix + "wk.weight"].clone(), + nH=((num_kv_heads + mapping.tp_size - 1) // mapping.tp_size), + d=tensorrt_llm_llama.hidden_size, + dH=head_size, + ) + v_weight = ckpt[prefix + "wv.weight"].clone() + + qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0) + ckpt[prefix + "qkv.weight"] = qkv_weight + + for k, v in ckpt.items(): + v = torch_to_numpy(v.to(torch_dtype).detach().cpu()) + if "tok_embeddings" in k: + if not tensorrt_llm_llama.use_parallel_embedding: + v = gather_embedding(v, k, num_ckpts) + elif tensorrt_llm_llama.embedding_sharding_dim == 0: + # this needs a gather and then resplit along different dims + v = gather_embedding(v, k, num_ckpts) + v = split(v, mapping.tp_size, mapping.tp_rank, 0) + if mapping.is_first_pp_rank(): + tensorrt_llm_llama.vocab_embedding.weight.value = v + elif "output" in k: + if mapping.is_last_pp_rank(): + tensorrt_llm_llama.lm_head.weight.value = v + elif k == "norm.weight": + if mapping.is_last_pp_rank(): + tensorrt_llm_llama.ln_f.weight.value = v + else: + # layer specific weights + layer_idx = extract_layer_idx(k) + if layer_idx is None: + continue + idx = int(layer_idx) - mapping.pp_rank * tensorrt_llm_llama.num_layers + if idx >= tensorrt_llm_llama.num_layers: + continue + if "attention_norm.weight" in k: + tensorrt_llm_llama.layers[idx].input_layernorm.weight.value = v + elif "ffn_norm.weight" in k: + tensorrt_llm_llama.layers[idx].post_layernorm.weight.value = v + elif "feed_forward.w3.weight" in k: + tensorrt_llm_llama.layers[idx].mlp.gate.weight.value = v + elif "feed_forward.w2.weight" in k: + tensorrt_llm_llama.layers[idx].mlp.proj.weight.value = v + elif "feed_forward.w1.weight" in k: + tensorrt_llm_llama.layers[idx].mlp.fc.weight.value = v + elif "attention.wo.weight" in k: + tensorrt_llm_llama.layers[idx].attention.dense.weight.value = v + elif "attention.qkv.weight" in k: + tensorrt_llm_llama.layers[idx].attention.qkv.weight.value = v + + tok = time.time() + t = time.strftime("%H:%M:%S", time.gmtime(tok - tik)) + tensorrt_llm.logger.info(f"Weights loaded. 
Total time: {t}") + return + + +def load_from_binary( + tensorrt_llm_llama: LLaMAForCausalLM, + dir_path, + mapping=Mapping(), + fp16=False, + multi_query_mode=False, +): + tensorrt_llm.logger.info("Loading weights from FT...") + tik = time.time() + + quant_mode = getattr(tensorrt_llm_llama, "quant_mode", QuantMode(0)) + + ( + n_embd, + n_head, + n_layer, + n_positions, + vocab_size, + hidden_act, + inter_size, + n_kv_head, + ) = parse_ft_config(Path(dir_path) / "config.ini") + np_dtype = np.float16 if fp16 else np.float32 + + def fromfile(dir_path, name, shape=None, dtype=None): + dtype = np_dtype if dtype is None else dtype + p = dir_path + "/" + name + if Path(p).exists(): + t = np.fromfile(p, dtype=dtype) + if shape is not None: + t = t.reshape(shape) + return t + return None + + def set_smoothquant_scale_factors( + module, + pre_scale_weight, + dir_path, + basename, + shape, + per_tok_dyn, + per_channel, + is_qkv=False, + rank=None, + ): + suffix = "bin" + if per_channel: + if rank is not None: + suffix = f"{rank}." + suffix + suffix = "col." + suffix + + col_shape = shape if (per_channel or is_qkv) else [1, 1] + + if per_tok_dyn: + if pre_scale_weight is not None: + pre_scale_weight.value = np.array([1.0], dtype=np.float32) + if is_qkv and not per_channel: + t = fromfile( + dir_path, + f"{basename}scale_w_quant_orig.{rank}.{suffix}", + col_shape, + np.float32, + ) + else: + t = fromfile( + dir_path, + f"{basename}scale_w_quant_orig.{suffix}", + col_shape, + np.float32, + ) + module.per_channel_scale.value = t + else: + t = fromfile(dir_path, f"{basename}scale_x_orig_quant.bin", [1], np.float32) + pre_scale_weight.value = t + if is_qkv: + t = fromfile( + dir_path, + f"{basename}scale_y_accum_quant.{rank}.{suffix}", + col_shape, + np.float32, + ) + else: + t = fromfile( + dir_path, + f"{basename}scale_y_accum_quant.{suffix}", + col_shape, + np.float32, + ) + module.per_channel_scale.value = t + t = fromfile( + dir_path, f"{basename}scale_y_quant_orig.bin", [1, 1], np.float32 + ) + module.act_scale.value = t + + def set_smoother(module, dir_path, base_name, shape, rank): + suffix = f"{rank}.bin" + t = fromfile(dir_path, f"{base_name}.smoother.{suffix}", shape, np.float32) + module.smoother.value = t + + # Determine the quantization mode. + quant_mode = getattr(tensorrt_llm_llama, "quant_mode", QuantMode(0)) + if quant_mode.is_int8_weight_only(): + plugin_weight_only_quant_type = torch.int8 + elif quant_mode.is_int4_weight_only(): + plugin_weight_only_quant_type = torch.quint4x2 + # Do we use SmoothQuant? + use_smooth_quant = quant_mode.has_act_and_weight_quant() + # Do we use quantization per token? + quant_per_token_dyn = quant_mode.has_per_token_dynamic_scaling() + # Do we use quantization per channel? + quant_per_channel = quant_mode.has_per_channel_scaling() + + # Do we use INT4/INT8 weight-only? + use_weight_only = quant_mode.is_weight_only() + + # Int8 KV cache + use_int8_kv_cache = quant_mode.has_int8_kv_cache() + + def sq_trick(x): + return x.view(np.float32) if use_smooth_quant else x + + # Debug + suffix = gen_suffix(mapping.tp_rank, use_smooth_quant, quant_per_channel) + # The type of weights. 
+ w_type = np_dtype if not use_smooth_quant else np.int8 + + if mapping.is_first_pp_rank(): + tensorrt_llm_llama.vocab_embedding.weight.value = fromfile( + dir_path, "vocab_embedding.weight.bin", [vocab_size, n_embd] + ) + + if mapping.is_last_pp_rank(): + tensorrt_llm_llama.ln_f.weight.value = fromfile(dir_path, "ln_f.weight.bin") + # share input embedding + lm_head_weight = fromfile(dir_path, "lm_head.weight.bin", [vocab_size, n_embd]) + + if vocab_size % mapping.tp_size != 0: + # padding + vocab_size_padded = tensorrt_llm_llama.lm_head.out_features * mapping.tp_size + pad_width = vocab_size_padded - vocab_size + lm_head_weight = np.pad( + lm_head_weight, ((0, pad_width), (0, 0)), "constant", constant_values=0 + ) + if mapping.is_last_pp_rank(): + tensorrt_llm_llama.lm_head.weight.value = np.ascontiguousarray( + split(lm_head_weight, mapping.tp_size, mapping.tp_rank) + ) + + layers_range = list( + range( + mapping.pp_rank * tensorrt_llm_llama.num_layers, + (mapping.pp_rank + 1) * tensorrt_llm_llama.num_layers, + 1, + ) + ) + + for i in layers_range: + n_groups = n_head // n_kv_head + c_attn_out_dim = ( + (3 * n_embd // mapping.tp_size) + if not multi_query_mode + else ( + n_embd // mapping.tp_size + + (n_embd // n_head * n_groups) // mapping.tp_size * 2 + ) + ) + idx = i - mapping.pp_rank * tensorrt_llm_llama.num_layers + tensorrt_llm_llama.layers[idx].input_layernorm.weight.value = fromfile( + dir_path, "model.layers." + str(i) + ".input_layernorm.weight.bin" + ) + t = fromfile( + dir_path, + "model.layers." + str(i) + ".attention.query_key_value.weight." + suffix, + [n_embd, c_attn_out_dim], + w_type, + ) + if t is not None: + dst = tensorrt_llm_llama.layers[idx].attention.qkv.weight + if use_smooth_quant: + dst.value = sq_trick(np.ascontiguousarray(np.transpose(t, [1, 0]))) + set_smoothquant_scale_factors( + tensorrt_llm_llama.layers[idx].attention.qkv, + tensorrt_llm_llama.layers[idx].input_layernorm.scale_to_int, + dir_path, + "model.layers." + str(i) + ".attention.query_key_value.", + [1, c_attn_out_dim], + quant_per_token_dyn, + quant_per_channel, + rank=mapping.tp_rank, + is_qkv=True, + ) + elif use_weight_only: + ( + processed_torch_weights, + torch_weight_scales, + ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix( + torch.tensor(t), plugin_weight_only_quant_type + ) + # workaround for trt not supporting int8 inputs in plugins currently + dst.value = processed_torch_weights.view(dtype=torch.float32).numpy() + scales = tensorrt_llm_llama.layers[i].attention.qkv.per_channel_scale + scales.value = torch_weight_scales.numpy() + else: + dst.value = np.ascontiguousarray(np.transpose(t, [1, 0])) + + dst = tensorrt_llm_llama.layers[idx].attention.dense.weight + t = fromfile( + dir_path, + "model.layers." + str(i) + ".attention.dense.weight." + suffix, + [n_embd // mapping.tp_size, n_embd], + w_type, + ) + if use_smooth_quant: + dst.value = sq_trick(np.ascontiguousarray(np.transpose(t, [1, 0]))) + dense_scale = getattr( + tensorrt_llm_llama.layers[idx].attention, + "quantization_scaling_factor", + None, + ) + set_smoothquant_scale_factors( + tensorrt_llm_llama.layers[idx].attention.dense, + dense_scale, + dir_path, + "model.layers." + str(i) + ".attention.dense.", + [1, n_embd], + quant_per_token_dyn, + quant_per_channel, + ) + set_smoother( + tensorrt_llm_llama.layers[idx].attention.dense, + dir_path, + "model.layers." 
+ str(i) + ".attention.dense", + [1, n_embd // mapping.tp_size], + mapping.tp_rank, + ) + elif use_weight_only: + ( + processed_torch_weights, + torch_weight_scales, + ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix( + torch.tensor(t), plugin_weight_only_quant_type + ) + # workaround for trt not supporting int8 inputs in plugins currently + dst.value = processed_torch_weights.view(dtype=torch.float32).numpy() + scales = tensorrt_llm_llama.layers[i].attention.dense.per_channel_scale + scales.value = torch_weight_scales.numpy() + else: + dst.value = np.ascontiguousarray(np.transpose(t, [1, 0])) + + dst = tensorrt_llm_llama.layers[idx].post_layernorm.weight + dst.value = fromfile( + dir_path, "model.layers." + str(i) + ".post_layernorm.weight.bin" + ) + + t = fromfile( + dir_path, + "model.layers." + str(i) + ".mlp.fc.weight." + suffix, + [n_embd, inter_size // mapping.tp_size], + w_type, + ) + + if use_smooth_quant: + tensorrt_llm_llama.layers[idx].mlp.fc.weight.value = sq_trick( + np.ascontiguousarray(np.transpose(t, [1, 0])) + ) + set_smoothquant_scale_factors( + tensorrt_llm_llama.layers[idx].mlp.fc, + tensorrt_llm_llama.layers[idx].post_layernorm.scale_to_int, + dir_path, + "model.layers." + str(i) + ".mlp.fc.", + [1, inter_size // mapping.tp_size], + quant_per_token_dyn, + quant_per_channel, + rank=mapping.tp_rank, + ) + elif use_weight_only: + dst = tensorrt_llm_llama.layers[i].mlp.fc.weight + ( + processed_torch_weights, + torch_weight_scales, + ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix( + torch.tensor(t), plugin_weight_only_quant_type + ) + # workaround for trt not supporting int8 inputs in plugins currently + dst.value = processed_torch_weights.view(dtype=torch.float32).numpy() + scales = tensorrt_llm_llama.layers[i].mlp.fc.per_channel_scale + scales.value = torch_weight_scales.numpy() + else: + tensorrt_llm_llama.layers[idx].mlp.fc.weight.value = np.ascontiguousarray( + np.transpose(t, [1, 0]) + ) + + t = fromfile( + dir_path, + "model.layers." + str(i) + ".mlp.gate.weight." + suffix, + [n_embd, inter_size // mapping.tp_size], + w_type, + ) + if use_smooth_quant: + tensorrt_llm_llama.layers[idx].mlp.gate.weight.value = sq_trick( + np.ascontiguousarray(np.transpose(t, [1, 0])) + ) + set_smoothquant_scale_factors( + tensorrt_llm_llama.layers[idx].mlp.gate, + tensorrt_llm_llama.layers[idx].post_layernorm.scale_to_int, + dir_path, + "model.layers." + str(i) + ".mlp.gate.", + [1, inter_size // mapping.tp_size], + quant_per_token_dyn, + quant_per_channel, + rank=mapping.tp_rank, + ) + elif use_weight_only: + dst = tensorrt_llm_llama.layers[i].mlp.gate.weight + ( + processed_torch_weights, + torch_weight_scales, + ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix( + torch.tensor(t), plugin_weight_only_quant_type + ) + # workaround for trt not supporting int8 inputs in plugins currently + dst.value = processed_torch_weights.view(dtype=torch.float32).numpy() + scales = tensorrt_llm_llama.layers[i].mlp.gate.per_channel_scale + scales.value = torch_weight_scales.numpy() + else: + tensorrt_llm_llama.layers[idx].mlp.gate.weight.value = np.ascontiguousarray( + np.transpose(t, [1, 0]) + ) + + t = fromfile( + dir_path, + "model.layers." + str(i) + ".mlp.proj.weight." 
+ suffix, + [inter_size // mapping.tp_size, n_embd], + w_type, + ) + if use_smooth_quant: + tensorrt_llm_llama.layers[idx].mlp.proj.weight.value = sq_trick( + np.ascontiguousarray(np.transpose(t, [1, 0])) + ) + proj_scale = getattr( + tensorrt_llm_llama.layers[idx].mlp, "quantization_scaling_factor", None + ) + set_smoothquant_scale_factors( + tensorrt_llm_llama.layers[idx].mlp.proj, + proj_scale, + dir_path, + "model.layers." + str(i) + ".mlp.proj.", + [1, n_embd], + quant_per_token_dyn, + quant_per_channel, + ) + set_smoother( + tensorrt_llm_llama.layers[idx].mlp.proj, + dir_path, + "model.layers." + str(i) + ".mlp.proj", + [1, inter_size // mapping.tp_size], + mapping.tp_rank, + ) + elif use_weight_only: + dst = tensorrt_llm_llama.layers[i].mlp.proj.weight + ( + processed_torch_weights, + torch_weight_scales, + ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix( + torch.tensor(t), plugin_weight_only_quant_type + ) + # workaround for trt not supporting int8 inputs in plugins currently + dst.value = processed_torch_weights.view(dtype=torch.float32).numpy() + scales = tensorrt_llm_llama.layers[i].mlp.proj.per_channel_scale + scales.value = torch_weight_scales.numpy() + else: + tensorrt_llm_llama.layers[idx].mlp.proj.weight.value = np.ascontiguousarray( + np.transpose(t, [1, 0]) + ) + + if use_int8_kv_cache: + t = fromfile( + dir_path, + "model.layers." + + str(i) + + ".attention.query_key_value.scale_y_quant_orig.bin", + [1], + np.float32, + ) + tensorrt_llm_llama.layers[idx].attention.kv_orig_quant_scale.value = 1.0 / t + tensorrt_llm_llama.layers[idx].attention.kv_quant_orig_scale.value = t + + tok = time.time() + t = time.strftime("%H:%M:%S", time.gmtime(tok - tik)) + tensorrt_llm.logger.info(f"Weights loaded. Total time: {t}") + + +def load_from_gptq_llama( + tensorrt_llm_llama, quant_ckpt_path, mapping=Mapping(), dtype="float16" +): + tensorrt_llm.logger.info("Loading weights from groupwise GPTQ LLaMA safetensors...") + tik = time.time() + + if quant_ckpt_path.endswith(".safetensors"): + groupwise_qweight_safetensors = safe_open( + quant_ckpt_path, framework="pt", device=0 + ) + model_params = { + key: groupwise_qweight_safetensors.get_tensor(key) + for key in groupwise_qweight_safetensors.keys() + } + elif quant_ckpt_path.endswith(".pt"): + model_params = torch.load(quant_ckpt_path, map_location=torch.device("cpu")) + else: + assert False, "Quantized checkpoint format not supported!" 
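# The GPTQ checkpoint loaded above stores eight 4-bit weights in every int32
# element, lowest nibble first; the unpack helper defined next inverts exactly
# that layout. A self-contained toy packer (hypothetical, illustration only)
# showing the convention the unpacker assumes:
import torch

def _pack_uint4_into_int32(vals):
    """vals: (rows, cols) integer tensor with entries in [0, 15], cols divisible by 8."""
    lo = vals[:, 0::2].to(torch.uint8) & 0xF   # even columns become low nibbles
    hi = vals[:, 1::2].to(torch.uint8) & 0xF   # odd columns become high nibbles
    packed_bytes = (hi << 4) | lo              # (rows, cols // 2) bytes
    return packed_bytes.contiguous().view(torch.int32)  # (rows, cols // 8) int32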
+ + def unpack_int32_into_int8(w_packed): + # Unpack inputs packed in int32/float32 into uint4 and store them in int8 format + w_packed_int4x2 = w_packed.contiguous().view(torch.uint8) + w_unpacked = torch.zeros( + w_packed_int4x2.shape[0], w_packed_int4x2.shape[1] * 2, dtype=torch.int8 + ) + w_unpacked[:, ::2] = w_packed_int4x2 % 16 + w_unpacked[:, 1::2] = w_packed_int4x2 // 16 + return w_unpacked.contiguous() + + def preprocess_groupwise_weight_params( + weight_name, qweight_int32=None, qzeros_int32=None, scales_fp16=None + ): + if weight_name is not None: + qweight_int32 = model_params[weight_name].cpu() + qzeros_int32 = model_params[weight_name[:-7] + "qzeros"].cpu() + scales_fp16 = model_params[weight_name[:-7] + "scales"].cpu() + + UINT4_TO_INT4_FLAG = 1 + GPTQ_FLAG = 1 + packer = torch.ops.fastertransformer.pack_int8_tensor_to_packed_int4 + preprocessor = torch.ops.fastertransformer.preprocess_weights_for_mixed_gemm + + qweight_unpacked_int8 = ( + unpack_int32_into_int8(qweight_int32.T).T.contiguous() - 8 + ) + qweight_interleaved = preprocessor( + packer(qweight_unpacked_int8), torch.quint4x2 + ).view(torch.float32) + # zeros = zeros * scales + qzeros_unpacked_int32 = unpack_int32_into_int8(qzeros_int32) + zeros_x_scales_fp16 = ( + -qzeros_unpacked_int32 + 8 * UINT4_TO_INT4_FLAG - GPTQ_FLAG + ) * scales_fp16 + zeros_x_scales_fp16 = zeros_x_scales_fp16.half() + + # return processed interleaved weight, original scales and zeros * scales + return ( + qweight_interleaved.contiguous(), + scales_fp16.contiguous(), + zeros_x_scales_fp16.contiguous(), + ) + + layer_ids = [extract_layer_idx(key) for key in groupwise_qweight_safetensors.keys()] + layer_ids = [int(layer_idx) for layer_idx in layer_ids if layer_idx is not None] + num_hidden_layers = max(layer_ids) + 1 + num_kv_heads = tensorrt_llm_llama.num_kv_heads + mha_mode = num_kv_heads == tensorrt_llm_llama.num_heads + suffixs = ["qweight", "qzeros", "scales"] + + layers_per_pipeline_stage = num_hidden_layers // mapping.pp_size + layers_range = list( + range( + mapping.pp_rank * layers_per_pipeline_stage, + (mapping.pp_rank + 1) * layers_per_pipeline_stage, + 1, + ) + ) + + for l in layers_range: + prefix = f"model.layers.{l}.self_attn." + split_qkv_suf = [] + + for suf in suffixs: + q_part = model_params[prefix + "q_proj." + suf].cpu() + k_part = model_params[prefix + "k_proj." + suf].cpu() + v_part = model_params[prefix + "v_proj." 
+ suf].cpu() + qkv_part = torch.cat([q_part, k_part, v_part], dim=0) + dim = qkv_part.shape + qkv_part = qkv_part.reshape(3, dim[0] // 3, dim[1]) + split_qkv = qkv_part.split(dim[1] // mapping.tp_size, dim=2)[ + mapping.tp_rank + ] + split_qkv = torch.cat( + [ + split_qkv[0, :, :].squeeze(0), + split_qkv[1, :, :].squeeze(0), + split_qkv[2, :, :].squeeze(0), + ], + dim=1, + ) + split_qkv_suf.append(split_qkv) + + th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params( + None, split_qkv_suf[0], split_qkv_suf[1], split_qkv_suf[2] + ) + + idx = l - mapping.pp_rank * layers_per_pipeline_stage + tensorrt_llm_llama.layers[idx].attention.qkv.qweight.value = th_qweight.numpy() + tensorrt_llm_llama.layers[idx].attention.qkv.scale.value = th_zero.numpy() + tensorrt_llm_llama.layers[idx].attention.qkv.zero.value = th_scale.numpy() + + torch_dtype = str_dtype_to_torch(dtype) + + for k, v in model_params.items(): + if isinstance(v, list): + v = [torch_to_numpy(vv.to(torch_dtype).detach().cpu()) for vv in v] + else: + v = torch_to_numpy(v.to(torch_dtype).detach().cpu()) + if "model.embed_tokens.weight" in k: + if mapping.is_first_pp_rank(): + tensorrt_llm_llama.vocab_embedding.weight.value = v + elif "model.norm.weight" in k: + if mapping.is_last_pp_rank(): + tensorrt_llm_llama.ln_f.weight.value = v + elif "lm_head.weight" in k: + if mapping.is_last_pp_rank(): + tensorrt_llm_llama.lm_head.weight.value = np.ascontiguousarray( + split(v, mapping.tp_size, mapping.tp_rank) + ) + else: + layer_idx = extract_layer_idx(k) + if layer_idx is None: + continue + idx = int(layer_idx) + if idx not in layers_range: + continue + idx = idx - mapping.pp_rank * layers_per_pipeline_stage + + if "input_layernorm.weight" in k: + tensorrt_llm_llama.layers[idx].input_layernorm.weight.value = v + elif "post_attention_layernorm.weight" in k: + tensorrt_llm_llama.layers[idx].post_layernorm.weight.value = v + elif "self_attn.o_proj.qweight" in k: + split_v_suf = [] + for suf in suffixs: + v = model_params[k[:-7] + suf].cpu() + split_v = v.split(v.shape[0] // mapping.tp_size, dim=0)[ + mapping.tp_rank + ] + split_v_suf.append(split_v) + th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params( + None, split_v_suf[0], split_v_suf[1], split_v_suf[2] + ) + tensorrt_llm_llama.layers[ + idx + ].attention.dense.qweight.value = th_qweight.numpy() + tensorrt_llm_llama.layers[ + idx + ].attention.dense.scale.value = th_zero.numpy() + tensorrt_llm_llama.layers[ + idx + ].attention.dense.zero.value = th_scale.numpy() + elif "mlp.up_proj.qweight" in k: + split_v_suf = [] + for suf in suffixs: + v = model_params[k[:-7] + suf].cpu() + split_v = v.split(v.shape[1] // mapping.tp_size, dim=1)[ + mapping.tp_rank + ] + split_v_suf.append(split_v) + th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params( + None, split_v_suf[0], split_v_suf[1], split_v_suf[2] + ) + tensorrt_llm_llama.layers[ + idx + ].mlp.gate.qweight.value = th_qweight.numpy() + tensorrt_llm_llama.layers[idx].mlp.gate.scale.value = th_zero.numpy() + tensorrt_llm_llama.layers[idx].mlp.gate.zero.value = th_scale.numpy() + elif "mlp.down_proj.qweight" in k: + split_v_suf = [] + for suf in suffixs: + v = model_params[k[:-7] + suf].cpu() + split_v = v.split(v.shape[0] // mapping.tp_size, dim=0)[ + mapping.tp_rank + ] + split_v_suf.append(split_v) + th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params( + None, split_v_suf[0], split_v_suf[1], split_v_suf[2] + ) + tensorrt_llm_llama.layers[ + idx + ].mlp.proj.qweight.value = th_qweight.numpy() + 
tensorrt_llm_llama.layers[idx].mlp.proj.scale.value = th_zero.numpy() + tensorrt_llm_llama.layers[idx].mlp.proj.zero.value = th_scale.numpy() + elif "mlp.gate_proj.qweight" in k: + split_v_suf = [] + for suf in suffixs: + v = model_params[k[:-7] + suf].cpu() + split_v = v.split(v.shape[1] // mapping.tp_size, dim=1)[ + mapping.tp_rank + ] + split_v_suf.append(split_v) + th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params( + None, split_v_suf[0], split_v_suf[1], split_v_suf[2] + ) + tensorrt_llm_llama.layers[idx].mlp.fc.qweight.value = th_qweight.numpy() + tensorrt_llm_llama.layers[idx].mlp.fc.scale.value = th_zero.numpy() + tensorrt_llm_llama.layers[idx].mlp.fc.zero.value = th_scale.numpy() + + tok = time.time() + t = time.strftime("%H:%M:%S", time.gmtime(tok - tik)) + tensorrt_llm.logger.info(f"Weights loaded. Total time: {t}") + return + + +def load_from_awq_llama( + tensorrt_llm_llama: LLaMAForCausalLM, + quant_ckpt_path, + mapping=Mapping(), + dtype="float16", +): + tensorrt_llm.logger.info("Loading weights from groupwise AWQ LLaMA safetensors...") + tik = time.time() + + if quant_ckpt_path.endswith(".safetensors"): + groupwise_qweight_safetensors = safe_open( + quant_ckpt_path, framework="pt", device=0 + ) + awq_llama = { + key: groupwise_qweight_safetensors.get_tensor(key) + for key in groupwise_qweight_safetensors.keys() + } + elif quant_ckpt_path.endswith(".pt"): + awq_llama = torch.load(quant_ckpt_path, map_location=torch.device("cpu")) + else: + assert False, "Quantized checkpoint format not supported!" + + group_size = ( + awq_llama["model.layers.0.self_attn.o_proj.weight"].numel() + // awq_llama["model.layers.0.self_attn.o_proj.weight_quantizer._amax"].numel() + ) + + awq_llama_block_names = [ + "input_layernorm.weight", + "post_attention_layernorm.weight", + ] + + tensorrt_llm_llama_block_names = [ + "input_layernorm.weight", + "post_layernorm.weight", + ] + + getattr(tensorrt_llm_llama, "quant_mode", QuantMode(0)) + + packer = torch.ops.fastertransformer.pack_int8_tensor_to_packed_int4 + preprocessor = torch.ops.fastertransformer.preprocess_weights_for_mixed_gemm + torch_dtype = str_dtype_to_torch(dtype) + + def AWQ_quantize_pack_preprocess(weight, scale): + scale = scale.repeat_interleave(group_size, dim=0) + weight = weight / scale + qweight_int8 = torch.clamp(torch.round(weight.cuda()).char(), -8, 7) + int4_weight = packer(qweight_int8.cpu()) + int4_weight = preprocessor(int4_weight, torch.quint4x2) + return int4_weight.view(torch.float32).cpu().numpy() + + def process_and_assign_weight(awq_llama, mPrefix, mOp, tp_dim=0): + weight = awq_llama[mPrefix + ".weight"].T.contiguous() + [k, n] = weight.shape + weight = weight.split(weight.shape[tp_dim] // mapping.tp_size, dim=tp_dim)[ + mapping.tp_rank + ] + amax = ( + awq_llama[mPrefix + ".weight_quantizer._amax"] + .reshape((n, int(k / group_size))) + .T.contiguous() + ) + amax = amax.split(amax.shape[tp_dim] // mapping.tp_size, dim=tp_dim)[ + mapping.tp_rank + ] + pre_quant_scale = awq_llama[ + mPrefix + ".input_quantizer._pre_quant_scale" + ].reshape((1, k)) + if tp_dim == 0: + pre_quant_scale = pre_quant_scale.split(k // mapping.tp_size, dim=1)[ + mapping.tp_rank + ] + scale = amax / 8.0 + mOp.qweight.value = AWQ_quantize_pack_preprocess(weight, scale) + mOp.scale.value = scale.to(torch_dtype).cpu().numpy() + mOp.pre_quant_scale.value = pre_quant_scale.to(torch_dtype).cpu().numpy() + + def deSmooth(weight, pre_quant_scale): + [k, n] = weight.shape + pre_quant_scale = pre_quant_scale.repeat((n, 1)).transpose(1, 
0).contiguous() + weight = weight * pre_quant_scale + return weight + + def reSmooth(weight, pre_quant_scale): + [k, n] = weight.shape + pre_quant_scale = pre_quant_scale.repeat((n, 1)).transpose(1, 0).contiguous() + weight = weight / pre_quant_scale + return weight + + def get_scale(weight): + weight = weight.T.contiguous() + [n, k] = weight.shape + weight = weight.reshape(n, int(k / group_size), group_size) + weight = torch.abs(weight.reshape(-1, group_size)) + amax, idx = weight.max(1) + amax = amax.reshape(n, int(k / group_size)).T.contiguous() + return amax / 8 + + def reSmooth_and_get_scale(weight, pre_quant_scale, avg_pre_quant_scale): + weight = deSmooth(weight, pre_quant_scale) + weight = reSmooth(weight, avg_pre_quant_scale) + scale = get_scale(weight) + return weight, scale + + def process_and_assign_qkv_weight(awq_llama, prefix, mOp): + q_weight = awq_llama[prefix + "self_attn.q_proj.weight"].T.contiguous() + k_weight = awq_llama[prefix + "self_attn.k_proj.weight"].T.contiguous() + v_weight = awq_llama[prefix + "self_attn.v_proj.weight"].T.contiguous() + k = q_weight.shape[0] + + q_weight = q_weight.split(q_weight.shape[1] // mapping.tp_size, dim=1)[ + mapping.tp_rank + ] + k_weight = k_weight.split(k_weight.shape[1] // mapping.tp_size, dim=1)[ + mapping.tp_rank + ] + v_weight = v_weight.split(v_weight.shape[1] // mapping.tp_size, dim=1)[ + mapping.tp_rank + ] + + q_pre_quant_scale = awq_llama[ + prefix + "self_attn.q_proj.input_quantizer._pre_quant_scale" + ].reshape((1, k)) + k_pre_quant_scale = awq_llama[ + prefix + "self_attn.k_proj.input_quantizer._pre_quant_scale" + ].reshape((1, k)) + v_pre_quant_scale = awq_llama[ + prefix + "self_attn.v_proj.input_quantizer._pre_quant_scale" + ].reshape((1, k)) + + qkv_pre_quant_scale = ( + q_pre_quant_scale + k_pre_quant_scale + v_pre_quant_scale + ) / 3.0 + q_weight, q_scale = reSmooth_and_get_scale( + q_weight, q_pre_quant_scale, qkv_pre_quant_scale + ) + k_weight, k_scale = reSmooth_and_get_scale( + k_weight, k_pre_quant_scale, qkv_pre_quant_scale + ) + v_weight, v_scale = reSmooth_and_get_scale( + v_weight, v_pre_quant_scale, qkv_pre_quant_scale + ) + + qkv_weights = torch.cat((q_weight, k_weight, v_weight), dim=1) + qkv_scale = torch.cat((q_scale, k_scale, v_scale), dim=1) + + mOp.pre_quant_scale.value = qkv_pre_quant_scale.to(torch_dtype).cpu().numpy() + mOp.qweight.value = AWQ_quantize_pack_preprocess(qkv_weights, qkv_scale) + mOp.scale.value = qkv_scale.to(torch_dtype).cpu().numpy() + + # Check if we need to pad vocab + v = awq_llama.get("model.embed_tokens.weight") + [vocab_size, k] = v.shape + pad_vocab = False + pad_vocab_size = vocab_size + if vocab_size % 64 != 0: + pad_vocab = True + pad_vocab_size = int((vocab_size + 63) / 64) * 64 + if pad_vocab: + new_v = torch.zeros([pad_vocab_size, k]) + new_v[:vocab_size, :] = v + v = new_v + if mapping.is_first_pp_rank(): + tensorrt_llm_llama.vocab_embedding.weight.value = ( + v.to(torch_dtype).cpu().numpy() + ) + + layer_ids = [extract_layer_idx(key) for key in awq_llama.keys()] + layer_ids = [int(layer_idx) for layer_idx in layer_ids if layer_idx is not None] + + num_hidden_layers = max(layer_ids) + 1 + layers_per_pipeline_stage = num_hidden_layers // mapping.pp_size + layers_range = list( + range( + mapping.pp_rank * layers_per_pipeline_stage, + (mapping.pp_rank + 1) * layers_per_pipeline_stage, + 1, + ) + ) + + for layer_idx in layers_range: + prefix = "model.layers." + str(layer_idx) + "." 
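# get_scale() above collapses each block of `group_size` input channels into a
# single scale, amax / 8, so the quantized weights fit the signed int4 range
# [-8, 7]. A compact, hypothetical restatement of that reduction (illustration
# only, not called by the loader):
def _toy_group_scale(weight, group_size):
    """weight: (k, n) tensor; returns one scale per (k // group_size, n) block."""
    k, n = weight.shape
    amax = weight.T.reshape(n, k // group_size, group_size).abs().amax(dim=2)
    return amax.T.contiguous() / 8.0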
+ tensorrt_llm.logger.info(f"Process weights in layer: {layer_idx}") + for idx, awq_attr in enumerate(awq_llama_block_names): + v = awq_llama[prefix + awq_attr] + layer = attrgetter(tensorrt_llm_llama_block_names[idx])( + tensorrt_llm_llama.layers[layer_idx] + ) + setattr(layer, "value", v.to(torch_dtype).cpu().numpy()) + + # Attention QKV Linear + # concatenate the Q, K, V layers weights. + process_and_assign_qkv_weight( + awq_llama, prefix, tensorrt_llm_llama.layers[layer_idx].attention.qkv + ) + + # Attention Dense (out_proj) Linear + mPrefix = prefix + "self_attn.o_proj" + mOp = tensorrt_llm_llama.layers[layer_idx].attention.dense + process_and_assign_weight(awq_llama, mPrefix, mOp, 0) + + # MLP up_proj (mlp.gate) Linear + mPrefix = prefix + "mlp.up_proj" + mOp = tensorrt_llm_llama.layers[layer_idx].mlp.gate + process_and_assign_weight(awq_llama, mPrefix, mOp, 1) + + # MLP down_proj (mlp.proj) Linear + mPrefix = prefix + "mlp.down_proj" + mOp = tensorrt_llm_llama.layers[layer_idx].mlp.proj + process_and_assign_weight(awq_llama, mPrefix, mOp, 0) + + # MLP gate_proj (mlp.fc) Linear + mPrefix = prefix + "mlp.gate_proj" + mOp = tensorrt_llm_llama.layers[layer_idx].mlp.fc + process_and_assign_weight(awq_llama, mPrefix, mOp, 1) + + v = awq_llama["model.norm.weight"] + if mapping.is_last_pp_rank(): + tensorrt_llm_llama.ln_f.weight.value = v.to(torch_dtype).cpu().numpy() + + # lm_head + if pad_vocab: + weight = awq_llama["lm_head.weight"] + [vocab_size, k] = weight.shape + new_weight = torch.zeros([pad_vocab_size, k]) + new_weight[:vocab_size, :] = weight + new_weight = new_weight.T.contiguous() + amax = awq_llama["lm_head.weight_quantizer._amax"].reshape( + [vocab_size, k // group_size] + ) + new_amax = torch.ones([pad_vocab_size, k // group_size]) + new_amax[:vocab_size, :] = amax + new_amax = new_amax.T.contiguous() + new_scale = new_amax / 8 + tensorrt_llm_llama.lm_head.qweight.value = AWQ_quantize_pack_preprocess( + new_weight, new_scale + ) + tensorrt_llm_llama.lm_head.scale.value = new_scale.to(torch_dtype).cpu().numpy() + tensorrt_llm_llama.lm_head.pre_quant_scale.value = ( + awq_llama["lm_head.input_quantizer._pre_quant_scale"] + .to(torch_dtype) + .cpu() + .numpy() + ) + else: + mPrefix = "lm_head" + mOp = tensorrt_llm_llama.lm_head + if mapping.is_last_pp_rank(): + process_and_assign_weight(awq_llama, mPrefix, mOp, 1) + + tok = time.time() + t = time.strftime("%H:%M:%S", time.gmtime(tok - tik)) + tensorrt_llm.logger.info(f"Weights loaded. Total time: {t}") diff --git a/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/gptnext b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/gptnext new file mode 120000 index 00000000..056bf100 --- /dev/null +++ b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/gptnext @@ -0,0 +1 @@ +llama \ No newline at end of file diff --git a/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/ensemble/1/.gitkeep b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/ensemble/1/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/ensemble/config.pbtxt b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/ensemble/config.pbtxt new file mode 100755 index 00000000..cbd087ce --- /dev/null +++ b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/ensemble/config.pbtxt @@ -0,0 +1,228 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "ensemble" +platform: "ensemble" +max_batch_size: 128 +input [ + { + name: "text_input" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "max_tokens" + data_type: TYPE_UINT32 + dims: [ -1 ] + }, + { + name: "end_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "pad_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "top_k" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "top_p" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "length_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "repetition_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "min_length" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "presence_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "random_seed" + data_type: TYPE_UINT64 + dims: [ 1 ] + optional: true + }, + { + name: "beam_width" + data_type: TYPE_UINT32 + dims: [ 1 ] + optional: true + }, + { + name: "stream" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + } +] +output [ + { + name: "text_output" + data_type: TYPE_STRING + dims: [ -1, -1 ] + } +] +ensemble_scheduling { + step [ + { + model_name: "preprocessing" + model_version: -1 + input_map { + key: "QUERY" + value: "text_input" + } + input_map { + key: "REQUEST_OUTPUT_LEN" + value: "max_tokens" + } + output_map { + key: "REQUEST_INPUT_LEN" + value: "_REQUEST_INPUT_LEN" + } + output_map { + key: "INPUT_ID" + value: "_INPUT_ID" + } + output_map { + key: "REQUEST_OUTPUT_LEN" + value: "_REQUEST_OUTPUT_LEN" + } + }, + { + model_name: "tensorrt_llm" + model_version: -1 + input_map { + key: "input_ids" + value: "_INPUT_ID" + } + input_map { + key: "input_lengths" + value: "_REQUEST_INPUT_LEN" + } + input_map { + key: "request_output_len" + value: "_REQUEST_OUTPUT_LEN" + } + input_map { + key: "end_id" + value: "end_id" + } + 
input_map { + key: "pad_id" + value: "pad_id" + } + input_map { + key: "runtime_top_k" + value: "top_k" + } + input_map { + key: "runtime_top_p" + value: "top_p" + } + input_map { + key: "temperature" + value: "temperature" + } + input_map { + key: "len_penalty" + value: "length_penalty" + } + input_map { + key: "repetition_penalty" + value: "repetition_penalty" + } + input_map { + key: "min_length" + value: "min_length" + } + input_map { + key: "presence_penalty" + value: "presence_penalty" + } + input_map { + key: "random_seed" + value: "random_seed" + } + input_map { + key: "beam_width" + value: "beam_width" + } + input_map { + key: "streaming" + value: "stream" + } + output_map { + key: "output_ids" + value: "_TOKENS_BATCH" + } + }, + { + model_name: "postprocessing" + model_version: -1 + input_map { + key: "TOKENS_BATCH" + value: "_TOKENS_BATCH" + } + output_map { + key: "OUTPUT" + value: "text_output" + } + } + ] +} diff --git a/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/postprocessing/1/model.py b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/postprocessing/1/model.py new file mode 100755 index 00000000..0e563c96 --- /dev/null +++ b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/postprocessing/1/model.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os + +import numpy as np +import triton_python_backend_utils as pb_utils +from transformers import LlamaTokenizer + +TOKENIZER_DIR = os.environ.get("TOKENIZER_DIR", "/model") + +SPACE_CHAR = 9601 +NEWLINE_CHAR = 60 +STOP_TOKEN = 2 + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to initialize any state associated with this model. + Parameters + ---------- + args : dict + Both keys and values are strings. 
The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + # Parse model configs + self.model_config = model_config = json.loads(args["model_config"]) + + # Parse model output configs + output_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT") + + # Convert Triton types to numpy types + self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) + + self.tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_DIR, legacy=False) + vocab = self.tokenizer.convert_ids_to_tokens( + list(range(self.tokenizer.vocab_size)) + ) + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + + responses = [] + + # Every Python backend must iterate over everyone of the requests + # and create a pb_utils.InferenceResponse for each of them. + for request in requests: + # Get input tensors + tokens_batch = pb_utils.get_input_tensor_by_name( + request, "TOKENS_BATCH" + ).as_numpy() + + # Reshape Input + # tokens_batch = tokens_batch.reshape([-1, tokens_batch.shape[0]]) + # tokens_batch = tokens_batch.T + + # Postprocessing output data. + outputs = self._postprocessing(tokens_batch) + + # Create output tensors. You need pb_utils.Tensor + # objects to create pb_utils.InferenceResponse. + output_tensor = pb_utils.Tensor( + "OUTPUT", np.array(outputs).astype(self.output_dtype) + ) + + # Create InferenceResponse. You can set an error here in case + # there was a problem with handling this inference request. + # Below is an example of how you can set errors in inference + # response: + # + # pb_utils.InferenceResponse( + # output_tensors=..., TritonError("An error occurred")) + inference_response = pb_utils.InferenceResponse( + output_tensors=[output_tensor] + ) + responses.append(inference_response) + + # You should return a list of pb_utils.InferenceResponse. Length + # of this list must match the length of `requests` list. + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + `Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. 
+ """ + pb_utils.Logger.log("Finalizing the Post-Processing Model.") + + def _id_to_token(self, token_id): + # handle special tokens (end of string, unknown, etc) + try: + special_token_index = self.tokenizer.all_special_ids.index(token_id) + return self.tokenizer.all_special_tokens[special_token_index] + except ValueError: + pass + + # handle typical tokens + tokens = self.tokenizer.convert_ids_to_tokens(token_id) + if ord(tokens[0]) == SPACE_CHAR: + return f" {tokens[1:]}" + if ord(tokens[0]) == NEWLINE_CHAR: + return "\n" + return tokens + + def _postprocessing(self, tokens_batch): + tokens_batch = tokens_batch.tolist() + return [ + self._id_to_token(token_id) + for beam_tokens in tokens_batch + for token_ids in beam_tokens + for token_id in token_ids + ] + + # for beam_tokens in tokens_batch: + # for token_ids in beam_tokens: + # for token_id in token_ids: + # # handle special tokens (end of string, unknown, etc) + # special_token = self.tokenizer.added_tokens_decoder.get(token_id) + # if special_token: + # tokens = special_token.content + + # # handle typical tokens + # else: + # tokens = self.tokenizer.convert_ids_to_tokens(token_id) + # if ord(tokens[0]) == SPACE_CHAR: + # tokens = f" {tokens[1:]}" + # elif ord(tokens[0]) == NEWLINE_CHAR: + # tokens = "\n" + + # outputs.append(tokens) + # return outputs diff --git a/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/postprocessing/config.pbtxt b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/postprocessing/config.pbtxt new file mode 100755 index 00000000..3c3ea10d --- /dev/null +++ b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/postprocessing/config.pbtxt @@ -0,0 +1,50 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +name: "postprocessing" +backend: "python" +max_batch_size: 128 +input [ + { + name: "TOKENS_BATCH" + data_type: TYPE_INT32 + dims: [ -1, -1 ] + } +] +output [ + { + name: "OUTPUT" + data_type: TYPE_STRING + dims: [ -1, -1 ] + } +] + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] diff --git a/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/preprocessing/1/model.py b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/preprocessing/1/model.py new file mode 100644 index 00000000..44e8b9c4 --- /dev/null +++ b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/preprocessing/1/model.py @@ -0,0 +1,244 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +import csv +import json +import os + +import numpy as np +import torch +import triton_python_backend_utils as pb_utils +from torch.nn.utils.rnn import pad_sequence +from transformers import LlamaTokenizer + +TOKENIZER_DIR = os.environ.get("TOKENIZER_DIR", "/model") + +END_ID = 2 + +# SYSTEM_PROMPT = ( +# """You are a helpful, respectful and honest assistant.""" +# """Always answer as helpfully as possible, while being safe.""" +# """Please ensure that your responses are positive in nature.""" +# ) + +# LLAMA_PROMPT_TEMPLATE = ( +# "[INST] <>" +# "{system_prompt}" +# "<>" +# "[/INST] {context} [INST] {question} [/INST]" +# ) + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to initialize any state associated with this model. + Parameters + ---------- + args : dict + Both keys and values are strings. 
The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + # Parse model configs + self.model_config = model_config = json.loads(args["model_config"]) + + # Parse model output configs and convert Triton types to numpy types + input_names = ["INPUT_ID", "REQUEST_INPUT_LEN"] + for input_name in input_names: + setattr( + self, + input_name.lower() + "_dtype", + pb_utils.triton_string_to_numpy( + pb_utils.get_output_config_by_name(model_config, input_name)[ + "data_type" + ] + ), + ) + + self.encoder = LlamaTokenizer.from_pretrained(TOKENIZER_DIR, legacy=False) + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + + responses = [] + + # Every Python backend must iterate over everyone of the requests + # and create a pb_utils.InferenceResponse for each of them. + for request in requests: + # Get input tensors + query = pb_utils.get_input_tensor_by_name(request, "QUERY").as_numpy() + request_output_len = pb_utils.get_input_tensor_by_name( + request, "REQUEST_OUTPUT_LEN" + ).as_numpy() + + input_id, request_input_len = self._create_request(query) + + # Create output tensors. You need pb_utils.Tensor + # objects to create pb_utils.InferenceResponse. + input_id_tensor = pb_utils.Tensor( + "INPUT_ID", np.array(input_id).astype(self.input_id_dtype) + ) + request_input_len_tensor = pb_utils.Tensor( + "REQUEST_INPUT_LEN", + np.array(request_input_len).astype(self.request_input_len_dtype), + ) + request_output_len_tensor = pb_utils.Tensor( + "REQUEST_OUTPUT_LEN", request_output_len + ) + + # Create InferenceResponse. You can set an error here in case + # there was a problem with handling this inference request. + # Below is an example of how you can set errors in inference + # response: + # + # pb_utils.InferenceResponse( + # output_tensors=..., TritonError("An error occurred")) + inference_response = pb_utils.InferenceResponse( + output_tensors=[ + input_id_tensor, + request_input_len_tensor, + request_output_len_tensor, + ] + ) + responses.append(inference_response) + + # You should return a list of pb_utils.InferenceResponse. Length + # of this list must match the length of `requests` list. + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. 
+ """ + pb_utils.Logger.log("Finalizing the Pre-Processing Model.") + + def _create_request(self, prompts): + """ + prompts : batch string (2D numpy array) + """ + + start_ids = [ + torch.IntTensor(self.encoder.encode(prompt[0].decode())) + for prompt in prompts + ] + + start_lengths = torch.IntTensor([[len(ids)] for ids in start_ids]) + + start_ids = pad_sequence(start_ids, batch_first=True, padding_value=END_ID) + + return start_ids, start_lengths + + def _create_word_list(self, word_dict): + flat_ids = [] + offsets = [] + for word_dict_item in word_dict: + item_flat_ids = [] + item_offsets = [] + + words = list(csv.reader([word_dict_item[0].decode()]))[0] + for word in words: + ids = self._encode(word) + + if len(ids) == 0: + continue + + item_flat_ids += ids + item_offsets.append(len(ids)) + + flat_ids.append(np.array(item_flat_ids)) + offsets.append(np.cumsum(np.array(item_offsets))) + + pad_to = max(1, max(len(ids) for ids in flat_ids)) + + for i, (ids, offs) in enumerate(zip(flat_ids, offsets)): + flat_ids[i] = np.pad(ids, (0, pad_to - len(ids)), constant_values=0) + offsets[i] = np.pad(offs, (0, pad_to - len(offs)), constant_values=-1) + + return np.array([flat_ids, offsets], dtype="int32").transpose((1, 0, 2)) + + def to_word_list_format(self, word_dict): + flat_ids = [] + offsets = [] + for word_dict_item in word_dict: + item_flat_ids = [] + item_offsets = [] + + if isinstance(word_dict_item[0], bytes): + word_dict_item = [word_dict_item[0].decode()] + + words = list(csv.reader(word_dict_item))[0] + for word in words: + ids = self.encoder.encode(word) + + if len(ids) == 0: + continue + + item_flat_ids += ids + item_offsets.append(len(ids)) + + flat_ids.append(np.array(item_flat_ids)) + offsets.append(np.cumsum(np.array(item_offsets))) + + pad_to = max(1, max(len(ids) for ids in flat_ids)) + + for i, (ids, offs) in enumerate(zip(flat_ids, offsets)): + flat_ids[i] = np.pad(ids, (0, pad_to - len(ids)), constant_values=0) + offsets[i] = np.pad(offs, (0, pad_to - len(offs)), constant_values=-1) + + return np.array([flat_ids, offsets], dtype="int32").transpose((1, 0, 2)) + + def _encode(self, sentence): + sentence = sentence.decode() if isinstance(sentence, bytes) else sentence + return self.encoder.encode(sentence) diff --git a/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/preprocessing/config.pbtxt b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/preprocessing/config.pbtxt new file mode 100644 index 00000000..d2e3029a --- /dev/null +++ b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/preprocessing/config.pbtxt @@ -0,0 +1,65 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "preprocessing" +backend: "python" +max_batch_size: 128 +input [ + { + name: "QUERY" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "REQUEST_OUTPUT_LEN" + data_type: TYPE_UINT32 + dims: [ -1 ] + } +] +output [ + { + name: "INPUT_ID" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "REQUEST_INPUT_LEN" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "REQUEST_OUTPUT_LEN" + data_type: TYPE_UINT32 + dims: [ -1 ] + } +] + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] diff --git a/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/tensorrt_llm/1/.gitkeep b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/tensorrt_llm/1/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/tensorrt_llm/config.pbtxt.j2 b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/tensorrt_llm/config.pbtxt.j2 new file mode 100644 index 00000000..4b719b04 --- /dev/null +++ b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/tensorrt_llm/config.pbtxt.j2 @@ -0,0 +1,208 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
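# Taken together, the preprocessing, tensorrt_llm and postprocessing configs are
# stitched into the "ensemble" model defined earlier, which accepts raw text and
# returns generated text. A minimal client sketch, assuming the standard
# tritonclient HTTP API and a Triton server on localhost:8000 (both are
# assumptions, not fixed by this patch):
import numpy as np
import tritonclient.http as httpclient

client = httpclient.InferenceServerClient(url="localhost:8000")

text = np.array([["What is TensorRT-LLM?"]], dtype=object)   # TYPE_STRING -> BYTES
max_tokens = np.array([[128]], dtype=np.uint32)               # TYPE_UINT32

inputs = [
    httpclient.InferInput("text_input", list(text.shape), "BYTES"),
    httpclient.InferInput("max_tokens", list(max_tokens.shape), "UINT32"),
]
inputs[0].set_data_from_numpy(text)
inputs[1].set_data_from_numpy(max_tokens)

result = client.infer(
    "ensemble", inputs, outputs=[httpclient.InferRequestedOutput("text_output")]
)
print(result.as_numpy("text_output"))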
+ +name: "tensorrt_llm" +backend: "tensorrtllm" +max_batch_size: 128 + +model_transaction_policy { + decoupled: {{ decoupled_mode }} +} + +input [ + { + name: "input_ids" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "input_lengths" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + }, + { + name: "request_output_len" + data_type: TYPE_UINT32 + dims: [ 1 ] + }, + { + name: "end_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "pad_id" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "beam_width" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_k" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_p" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "len_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "repetition_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "min_length" + data_type: TYPE_UINT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "presence_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "random_seed" + data_type: TYPE_UINT64 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "stop" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + }, + { + name: "streaming" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + } +] +output [ + { + name: "output_ids" + data_type: TYPE_INT32 + dims: [ -1, -1 ] + } +] +instance_group [ + { + count: 1 + kind : KIND_CPU + } +] +parameters: { + key: "max_beam_width" + value: { + string_value: "1" + } +} +parameters: { + key: "FORCE_CPU_ONLY_INPUT_TENSORS" + value: { + string_value: "no" + } +} +parameters: { + key: "gpt_model_type" + value: { + string_value: "{{ gpt_model_type }}" + } +} +parameters: { + key: "gpt_model_path" + value: { + string_value: "{{ engine_dir }}" + } +} +parameters: { + key: "max_tokens_in_paged_kv_cache" + value: { + string_value: "" + } +} +parameters: { + key: "batch_scheduler_policy" + value: { + string_value: "guaranteed_completion" + } +} +parameters: { + key: "kv_cache_free_gpu_mem_fraction" + value: { + string_value: ".75" + } +} +parameters: { + key: "max_num_sequences" + value: { + string_value: "" + } +} +parameters: { + key: "enable_trt_overlap" + value: { + string_value: "" + } +} diff --git a/RetrievalAugmentedGeneration/llm-inference-server/model_server/__init__.py b/RetrievalAugmentedGeneration/llm-inference-server/model_server/__init__.py new file mode 100644 index 00000000..c5141a38 --- /dev/null +++ b/RetrievalAugmentedGeneration/llm-inference-server/model_server/__init__.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Model-Server converts LLMs to TensorRT engines and hosts them with Triton.""" +import argparse +import logging +import os + +from .conversion import ConversionOptions, convert +from .errors import ModelServerException +from .model import Model +from .server import ModelServer + +_LOGGER = logging.getLogger(__name__) + + +def _mk_symlink(source: str, dest: str) -> None: + """Safely create a symbolic link.""" + if not os.path.exists(dest): + os.symlink(source, dest) + _LOGGER.debug("Creating symlink from %s to %s", source, dest) + + +def _azureml(source_directory: str) -> None: + """Make accomodations for AzureML.""" + _LOGGER.info("Detected running on AzureML.") + _LOGGER.debug("AZUREML_MODEL_DIR Variable is %s", source_directory) + destination_directory = "/model" + source_directory = os.path.abspath(source_directory) + _LOGGER.debug("Azure Model Directory is now %s", source_directory) + _LOGGER.debug( + "Azure model directory contents: %s", repr(os.listdir(source_directory)) + ) + + # find the direct path to the model weights + # $AZUREML_MODEL_DIR/model_name + try: + source_directory = os.path.join( + source_directory, os.listdir(source_directory)[0] + ) + _LOGGER.debug("Azure Model Directory is now %s", source_directory) + _LOGGER.debug( + "Azure model directory contents: %s", repr(os.listdir(source_directory)) + ) + except IndexError: + # pylint: disable-next=raise-missing-from + raise ModelServerException("AzureML folder structure is not recognized.") + + # create links for the model files to the /models directory + for root, dirs, files in os.walk(source_directory): + root_destination = root.replace(source_directory, destination_directory) + for fi in files: + _mk_symlink(os.path.join(root, fi), os.path.join(root_destination, fi)) + for di in dirs: + dest = os.path.join(root_destination, di) + os.makedirs(dest, exist_ok=True) + _LOGGER.debug("Creating directory %s", dest) + + +def _should_convert(args: argparse.Namespace, model: "Model") -> bool: + """Determine if the conversion step should run.""" + if args.force_conversion: + return True + + if args.no_conversion: + return False + + return model.conversion_is_needed() + + +def main(args: argparse.Namespace) -> int: + """Execute the model server.""" + # make accomidations for various ML platforms + azureml_model_dir = os.environ.get("AZUREML_MODEL_DIR") + if azureml_model_dir: + _azureml(azureml_model_dir) + + # load the model directory + _LOGGER.info("Reading the model directory.") + model = Model(model_type=args.type, world_size=args.world_size) + + # calculate the default parallism parameters + if not args.tensor_parallelism: + args.tensor_parallelism = max( + int(model.world_size / args.pipeline_parallelism), 1 + ) + if args.pipeline_parallelism * args.tensor_parallelism != model.world_size: + raise ModelServerException( + "Tensor Parallelism * Pipeline Parallelism must be equal to World Size" + ) + + conversion_opts = ConversionOptions( + max_input_length=args.max_input_length, + max_output_length=args.max_output_length, + tensor_parallelism=args.tensor_parallelism, + pipline_parallelism=args.pipeline_parallelism, 
+ ) + + # print discovered model parameters + _LOGGER.info("Model file format: %s", model.format.name) + _LOGGER.info("World Size: %d", model.world_size) + _LOGGER.info("Compute Capability: %s", model.compute_cap) + + # convert model + if _should_convert(args, model): + _LOGGER.info("Starting TensorRT Conversion.") + convert(model, conversion_opts) + else: + _LOGGER.info("TensorRT Conversion not required. Skipping.") + + # host model + if not args.no_hosting: + _LOGGER.info("Starting Triton Inference Server.") + inference_server = ModelServer(model, args.http) + return inference_server.run() + + return 0 diff --git a/RetrievalAugmentedGeneration/llm-inference-server/model_server/__main__.py b/RetrievalAugmentedGeneration/llm-inference-server/model_server/__main__.py new file mode 100644 index 00000000..b64791d7 --- /dev/null +++ b/RetrievalAugmentedGeneration/llm-inference-server/model_server/__main__.py @@ -0,0 +1,196 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Main entrypoint for the model-server application.""" +import argparse +import logging +import os +import sys + +from . import main +from .errors import ModelServerException +from .model import ModelTypes + +TERMINATION_LOG = "/dev/termination-log" + +_LOG_FMT = f"[{os.getpid()}] %(asctime)15s [%(levelname)7s] - %(name)s - %(message)s" +_LOG_DATE_FMT = "%b %d %H:%M:%S" +_LOGGER = logging.getLogger("main") + + +def parse_args() -> argparse.Namespace: + """Parse the comamnd line arguments.""" + parser = argparse.ArgumentParser( + prog="model-server", + description="Ingest models and host them with NVIDIA TensorRT LLM", + ) + + # options + parser.add_argument( + "-w", + "--world-size", + default=None, + type=int, + help="The number of GPUs to shard the model across. " + + "By default, this value will be equal to the number of available GPUs.", + ) + parser.add_argument( + "--force-conversion", + action="store_true", + help="When this flag is set, the TensorRT engine conversion will occur, " + + "even if a valid engine is in the cache.", + ) + parser.add_argument( + "--no-conversion", + action="store_true", + help="Skip the conversion. If no engine is available in the cache, an error will be raised.", + ) + parser.add_argument( + "--no-hosting", + action="store_true", + help="Do not start the Triton Inference Server. 
Only convert the model then exit.", + ) + parser.add_argument( + "-v", + "--verbose", + action="count", + default=1, + help="increase output verbosity", + ) + parser.add_argument( + "-q", + "--quiet", + action="count", + default=0, + help="decrease output verbosity", + ) + + # builder customization + parser.add_argument( + "--max-input-length", + type=int, + default=3000, + help="maximum number of input tokens", + ) + parser.add_argument( + "--max-output-length", + type=int, + default=512, + help="maximum number of output tokens", + ) + parser.add_argument( + "--tensor-parallelism", + type=int, + default=None, + help="number of tensor parallelism divisions (default: world_size/pipeline_parallelism)", + ) + parser.add_argument( + "--pipeline-parallelism", + type=int, + default=1, + help="number of pipeline parallism divisions (default: 1)", + ) + + # server customization + parser.add_argument( + "--http", + action="store_true", + help="change the api server to http instead of grpc (note: this will disable token streaming)", + ) + + # positional arguments + supported_model_types = [e.name.lower().replace("_", "-") for e in ModelTypes] + parser.add_argument( + "type", + metavar="TYPE", + choices=supported_model_types, + type=str.lower, + help=f"{supported_model_types} The type of model to process.", + ) + + args = parser.parse_args() + + if args.force_conversion and args.no_conversion: + parser.error("--force_conversion and --no-conversion are mutually exclusive.") + + return args + + +def _bootstrap_logging(verbosity: int = 0) -> None: + """Configure Python's logger according to the given verbosity level. + + :param verbosity: The desired verbosity level. Must be one of 0, 1, or 2. + :type verbosity: typing.Literal[0, 1, 2] + """ + # determine log level + verbosity = min(2, max(0, verbosity)) # limit verbosity to 0-2 + log_level = [logging.WARN, logging.INFO, logging.DEBUG][verbosity] + + # configure python's logger + logging.basicConfig(format=_LOG_FMT, datefmt=_LOG_DATE_FMT, level=log_level) + # update existing loggers + _LOGGER.setLevel(log_level) + # pylint: disable-next=no-member; false positive + for logger_name in logging.root.manager.loggerDict: + logger = logging.getLogger(logger_name) + for handler in logger.handlers: + handler.setFormatter(logging.Formatter(fmt=_LOG_FMT, datefmt=_LOG_DATE_FMT)) + + +def _k8s_error_handler(err: Exception) -> None: + """When running in Kubernetes, write errors to the termination log.""" + with open(TERMINATION_LOG, "w", encoding="UTF-8") as term_log: + # recursively write nested exceptions + def _write_errors_to_term_log(e: BaseException) -> None: + term_log.write(f"{type(e)}: {e}\n") + if e.__cause__: + _write_errors_to_term_log(e.__cause__) + + _write_errors_to_term_log(err) + + +def _error_handler(err: Exception) -> int: + """Catch and handle exceptions from the applicaiton.""" + # keybaord interrupts are fine + if isinstance(err, KeyboardInterrupt): + return 0 + + # on k8s, write errors to log file + if os.path.isfile(TERMINATION_LOG): + _k8s_error_handler(err) + + # raise uncaught errors + if not isinstance(err, ModelServerException): + raise err + + # gracefully handle caught errors + _LOGGER.error(str(err)) + + # if there is a nested error, raise it + if err.__cause__: + raise err.__cause__ + + # we decided to quite gracefully + return 1 + + +if __name__ == "__main__": + try: + _ARGS = parse_args() + _bootstrap_logging(_ARGS.verbose - _ARGS.quiet) + sys.exit(main(_ARGS)) + # pylint: disable-next=broad-exception-caught; Error handling based on 
type is done in the handler + except Exception as _ERR: + sys.exit(_error_handler(_ERR)) diff --git a/RetrievalAugmentedGeneration/llm-inference-server/model_server/conversion/__init__.py b/RetrievalAugmentedGeneration/llm-inference-server/model_server/conversion/__init__.py new file mode 100644 index 00000000..e32153fa --- /dev/null +++ b/RetrievalAugmentedGeneration/llm-inference-server/model_server/conversion/__init__.py @@ -0,0 +1,73 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains the logic for doing model conversions to TensorRT.""" +from dataclasses import dataclass +from typing import Optional + +from ..errors import ModelServerException +from ..model import Model, ModelFormats, ModelTypes + + +@dataclass +class ConversionOptions: + """Class containing the options used in TRT conversion.""" + + max_input_length: int + max_output_length: int + pipline_parallelism: int + tensor_parallelism: int + vocab_size: Optional[int] = None + + +def convert(model: Model, opts: ConversionOptions) -> None: + """ + Convert the provided model to TensorRT. + + Supported types and formats: + +----------+---------+---------+---------+---------+---------+ + | | NEMO | PYTORCH | ONNX | HFACE | UNKNOWN | + +----------+---------+---------+---------+---------+---------+ + | LLAMA | ✅ | ✅ | ❌ | ✅ | ❌ | + | GPTNEXT | ✅ | ❌ | ❌ | ❌ | ❌ | + +----------+---------+---------+---------+---------+---------+ + """ + if model.format == ModelFormats.NEMO: + # pylint: disable-next=import-outside-toplevel # preventing circular imports + from . import nemo + + nemo.convert(model, opts) + + elif model.type == ModelTypes.LLAMA: + # pylint: disable-next=import-outside-toplevel # preventing circular imports + from . import llama + + opts.vocab_size = 32000 + llama.convert(model, opts) + + elif model.type == ModelTypes.CODE_LLAMA: + # pylint: disable-next=import-outside-toplevel # preventing circular imports + from . import llama + + opts.vocab_size = 32016 + llama.convert(model, opts) + + else: + supported_types = [e.name for e in ModelTypes] + raise ModelServerException( + f"Unsupported model type. Conversion is supported for the following types: {supported_types}" + ) + + model.write_hash() diff --git a/RetrievalAugmentedGeneration/llm-inference-server/model_server/conversion/llama.py b/RetrievalAugmentedGeneration/llm-inference-server/model_server/conversion/llama.py new file mode 100644 index 00000000..5dc21552 --- /dev/null +++ b/RetrievalAugmentedGeneration/llm-inference-server/model_server/conversion/llama.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains the logic for exporting a Llama model in PyTorch format to TensorRT.""" +import logging +import os +import subprocess +import sys +import typing + +from ..errors import ModelServerException, UnsupportedFormatException +from ..model import Model +from . import ConversionOptions + +_CONVERSION_SCRIPTS = "/opt/conversion_scripts/llama" + +_CHECKPOINT_ARGS_FLAGS = {"PYTORCH": "--meta_ckpt_dir", "HUGGINGFACE": "--model_dir"} +_LOGGER = logging.getLogger(__name__) + + +def convert(model: Model, opts: ConversionOptions) -> None: + """Convert a llama model.""" + _LOGGER.debug("Running Llama model conversion.") + + # construct builder executable path + cwd = _CONVERSION_SCRIPTS + exe = [sys.executable, "build.py"] + + # construct builder env variables + env = os.environ + + # construct builder arguments + try: + raw_args: typing.List[str] = [ + "--max_input_len", + str(opts.max_input_length), + "--max_output_len", + str(opts.max_output_length), + "--dtype", + "float16", + "--use_gpt_attention_plugin", + "float16", + "--use_inflight_batching", + "--paged_kv_cache", + "--remove_input_padding", + "--use_gemm_plugin", + "float16", + "--output_dir", + model.engine_dir, + "--world_size", + str(model.world_size), + "--tp_size", + str(opts.tensor_parallelism), + "--pp_size", + str(opts.pipline_parallelism), + "--vocab_size", + str(opts.vocab_size), + _CHECKPOINT_ARGS_FLAGS[model.format.name], + model.model_dir, + ] + except KeyError as err: + raise UnsupportedFormatException( + model.format.name, ["PyTorch", "Hugging Face"] + ) from err + + # start the builder + _LOGGER.debug( + "Starting Llama exporter with the command: %s", " ".join(exe + raw_args) + ) + _LOGGER.debug("Starting Llama exporter with the env vars: %s", repr(env)) + with subprocess.Popen(exe + raw_args, env=env, cwd=cwd) as proc: + try: + retcode = proc.wait() + except KeyboardInterrupt: + proc.kill() + except Exception as err: + raise ModelServerException( + "Error running TensorRT model conversion." + ) from err + else: + if retcode != 0: + raise ModelServerException( + "TensorRT conversion returned a non-zero exit code." + ) diff --git a/RetrievalAugmentedGeneration/llm-inference-server/model_server/conversion/nemo.py b/RetrievalAugmentedGeneration/llm-inference-server/model_server/conversion/nemo.py new file mode 100644 index 00000000..d0eb10f3 --- /dev/null +++ b/RetrievalAugmentedGeneration/llm-inference-server/model_server/conversion/nemo.py @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This module contains the code for converting any .nemo formatted model to TRT LLM."""
+import logging
+import os
+from glob import glob
+from tarfile import TarFile
+from typing import IO, cast
+
+import yaml
+
+# pylint: disable-next=import-error
+from nemo.export import TensorRTLLM  # type: ignore
+
+from ..errors import ModelServerException
+from ..model import Model
+from . import ConversionOptions
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def convert(model: Model, _: ConversionOptions) -> None:
+    """Convert a .nemo formatted model."""
+    # find the .nemo model file
+    model_files = glob(os.path.join(model.model_dir, "*.nemo"))
+    if len(model_files) > 1:
+        raise ModelServerException(
+            "More than one NeMo checkpoint found in the model directory. "
+            + "Please only include one NeMo checkpoint file."
+        )
+
+    # verify that the model's parallelism matches the requested world size
+    config = {}
+    with TarFile(model_files[0], "r") as archive:
+        try:
+            config_file = cast(IO[bytes], archive.extractfile("./model_config.yaml"))
+        except KeyError:
+            config_file = cast(IO[bytes], archive.extractfile("model_config.yaml"))
+        config = yaml.safe_load(config_file)
+        config_file.close()
+
+    if config.get("tensor_model_parallel_size", 1) != model.world_size:
+        raise ModelServerException(
+            f"The provided model has a tensor parallelism of {config.get('tensor_model_parallel_size', 1)} "
+            + f"and the server has been requested to use {model.world_size} "
+            + "GPUs. Please use the NeMo inference container to resize the parallelism of the model or change "
+            + "the model-server's world size."
+        )
+
+    # run the nemo to trt llm conversion
+    trt_llm_exporter = TensorRTLLM(model_dir=model.engine_dir)
+    _LOGGER.info(".nemo to TensorRT Conversion started. This will take a few minutes.")
+    _LOGGER.info(model.engine_dir)
+    trt_llm_exporter.export(
+        nemo_checkpoint_path=model_files[0],
+        model_type=model.family,
+        n_gpus=model.world_size,
+    )
diff --git a/RetrievalAugmentedGeneration/llm-inference-server/model_server/errors.py b/RetrievalAugmentedGeneration/llm-inference-server/model_server/errors.py
new file mode 100644
index 00000000..5609a58a
--- /dev/null
+++ b/RetrievalAugmentedGeneration/llm-inference-server/model_server/errors.py
@@ -0,0 +1,32 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The custom errors raised by the model server."""
+import typing
+
+
+class ModelServerException(Exception):
+    """The base class for any custom exceptions."""
+
+
+class UnsupportedFormatException(ModelServerException):
+    """An error that indicates the model format is not supported for the provided type."""
+
+    def __init__(self, model_type: str, supported: typing.List[str]):
+        """Initialize the exception."""
+        super().__init__(
+            "Unsupported model type and format combination. 
" + + f"{model_type} models are supported in the following formats: {str(supported)}" + ) diff --git a/RetrievalAugmentedGeneration/llm-inference-server/model_server/model.py b/RetrievalAugmentedGeneration/llm-inference-server/model_server/model.py new file mode 100644 index 00000000..0f83459b --- /dev/null +++ b/RetrievalAugmentedGeneration/llm-inference-server/model_server/model.py @@ -0,0 +1,246 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains the model class that represents the model mounted to the container.""" +import glob +import hashlib +import logging +import os +import pathlib +import subprocess +import typing +from enum import Enum, auto, unique + +from .errors import ModelServerException + +DEFAULT_MODEL_DIR = "/model" +HASH_COMMAND = "sha1sum" +_LOGGER = logging.getLogger(__name__) + + +def _fast_hash_dir(dir_path: str) -> str: + """ + Read the files in a directory and quickly create a hash. + + This hash IS NOT cryptographically secure, but it is designed to be computed as quickly as reasonably possible. + This function will only hash top level files and will not traverse directories. 
+    """
+    # launch one hashing subprocess per file to calculate the individual hashes
+    workers = []
+    for obj in os.listdir(dir_path):
+        obj_path = os.path.join(dir_path, obj)
+        if not os.path.isfile(obj_path):
+            continue
+
+        workers += [
+            # pylint: disable-next=consider-using-with
+            subprocess.Popen(
+                [HASH_COMMAND, obj_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE
+            )
+        ]
+
+    # wait for workers to complete
+    all_shas = b""
+    for proc in workers:
+        stdout, _ = proc.communicate()
+        all_shas += stdout.split(b" ", maxsplit=1)[0]
+
+    hasher = hashlib.sha1(usedforsecurity=False)
+    hasher.update(all_shas)
+    return hasher.hexdigest()
+
+
+@unique
+class ModelFormats(Enum):
+    """An enumerator of the supported model formats."""
+
+    UNKNOWN = auto()
+    ONNX = auto()
+    PYTORCH = auto()
+    HUGGINGFACE = auto()
+    NEMO = auto()
+
+
+@unique
+class ModelTypes(Enum):
+    """An enumerator of the supported model types."""
+
+    LLAMA = auto()
+    CODE_LLAMA = auto()
+    GPTNEXT = auto()
+
+    @property
+    def family(self) -> str:
+        """Return the family grouping of the model."""
+        return ["llama", "llama", "gptnext"][self.value - 1]
+
+
+class Model:
+    """A representation of the mounted model."""
+
+    def __init__(
+        self,
+        model_type: str,
+        model_dir: typing.Optional[str] = None,
+        world_size: typing.Optional[int] = None,
+    ):
+        """Initialize the model class."""
+        try:
+            self._type = ModelTypes[model_type.upper().replace("-", "_")]
+        except KeyError as err:
+            raise ModelServerException(f"Unrecognized model type {model_type}") from err
+
+        self._model_dir = model_dir or DEFAULT_MODEL_DIR
+        self._gpu_info = self._init_gpu_info(world_size=world_size)
+        self._hash: typing.Optional[str] = None
+        self._engine_dir = self._init_engine_dir()
+        self._format = self._init_model_format()
+
+    @classmethod
+    def _init_gpu_info(
+        cls,
+        world_size: typing.Optional[int] = None,
+    ) -> typing.Dict[str, typing.Union[str, int]]:
+        """
+        Get the compute capability of the attached GPUs and the world size.
+
+        Returns
+        -------
+        Dict: A dictionary containing the GPU compute capability and the world size.
+        """
+        query_cmd = ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"]
+        gpu_info_raw = subprocess.check_output(query_cmd)
+        compute_caps = [cap.decode() for cap in gpu_info_raw.strip().split(b"\n")]
+        # FUTURE: convert this to use nvml instead
+
+        # do basic error checking
+        if len(compute_caps) == 0:
+            raise ModelServerException("No GPUs attached to the container.")
+        if len(set(compute_caps)) > 1:
+            raise ModelServerException(
+                "Attached GPUs are dissimilar. All GPUs must be of the same type."
+            )
+        if not world_size:
+            world_size = len(compute_caps)
+
+        return {"compute_cap": compute_caps[0], "world_size": world_size}
+
+    def _init_engine_dir(self) -> str:
+        """Create and return the path to the TensorRT cache directory for this model."""
+        cache_dir = f"trt-w{self.world_size}-cc{self.compute_cap}"
+        cache_path = os.path.join(self.model_dir, cache_dir)
+        pathlib.Path(cache_path).mkdir(parents=True, exist_ok=True)
+        return cache_path
+
+    def _init_model_format(self) -> ModelFormats:
+        """Determine the format of the model that has been mounted."""
+        # look for nemo checkpoints
+        nemo_count = self._file_ext_count("nemo")
+        if nemo_count == 1:
+            return ModelFormats.NEMO
+        if nemo_count > 1:
+            raise ModelServerException(
+                f"Only one nemo checkpoint file may be in the model directory. 
Found {nemo_count}", + ) + + # look for pytorch saved models + pytorch_count = self._file_ext_count("pth") + self._file_ext_count("pt") + if pytorch_count: + return ModelFormats.PYTORCH + + # look for huggingface saved models + hf_count = self._file_ext_count("bin") + if hf_count: + return ModelFormats.HUGGINGFACE + + # look for onnx models + onnx_count = self._file_ext_count("onnx") + if onnx_count: + return ModelFormats.ONNX + + return ModelFormats.UNKNOWN + + def _file_ext_count(self, extension: str) -> int: + """Count the files in a directory with a given extension.""" + path = os.path.join(self.model_dir, f"*.{extension}") + return len(glob.glob(path)) + + @property + def type(self) -> ModelTypes: + """Return the type of the model.""" + return self._type + + @property + def family(self) -> str: + """Return the model family grouping.""" + return self._type.family + + @property + def model_dir(self) -> str: + """Return the stored model directory.""" + return self._model_dir + + @property + def engine_dir(self) -> str: + """Return the stored engine directory.""" + return self._engine_dir + + @property + def world_size(self) -> int: + """Return the world size.""" + ws = self._gpu_info["world_size"] + return typing.cast(int, ws) + + @property + def compute_cap(self) -> str: + """Return the compute capability version.""" + cc = self._gpu_info["compute_cap"] + return typing.cast(str, cc) + + @property + def format(self) -> ModelFormats: + """Return the format of the model.""" + return self._format + + @property + def hash(self) -> str: + """Return the hash of the model.""" + if not self._hash: + _LOGGER.info("Calculating model hash.") + self._hash = _fast_hash_dir(self.model_dir) + return self._hash + + @property + def _last_hash_path(self) -> str: + """Return the path to the last known hash file.""" + return os.path.join(self.engine_dir, "hash") + + def conversion_is_needed(self) -> bool: + """Determine if the engine conversion is required.""" + if not os.path.isfile(self._last_hash_path): + _LOGGER.debug("No engine file exists. Will generate an engine file.") + return True + with open(self._last_hash_path, "r", encoding="ASCII") as hash_file: + last_hash = hash_file.read() + if last_hash != self.hash: + _LOGGER.debug("Change in model hash detected. Will regnerate engine file.") + return True + _LOGGER.debug("Existing engine file found.") + return False + + def write_hash(self) -> None: + """Write the model hash to the engine directory.""" + with open(self._last_hash_path, "w", encoding="ASCII") as hash_file: + hash_file.write(self.hash) diff --git a/RetrievalAugmentedGeneration/llm-inference-server/model_server/server.py b/RetrievalAugmentedGeneration/llm-inference-server/model_server/server.py new file mode 100644 index 00000000..a9077bf9 --- /dev/null +++ b/RetrievalAugmentedGeneration/llm-inference-server/model_server/server.py @@ -0,0 +1,153 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""This module contains the code to statup triton inference servers.""" +import logging +import os +import subprocess +import typing + +from jinja2 import Environment, FileSystemLoader + +from .model import Model, ModelFormats + +_ENSEMBLE_MODEL_DIR = "/opt/ensemble_models" +_TRITON_BIN = "/opt/tritonserver/bin/tritonserver" +_MPIRUN_BIN = "/usr/local/mpi/bin/mpirun" +_LOGGER = logging.getLogger(__name__) + + +class ModelServer: + """Abstraction of a multi-gpu triton inference server cluster.""" + + def __init__(self, model: Model, http: bool = False) -> None: + """Initialize the model server.""" + self._model = model + self._http = http + + @property + def _decoupled_mode(self) -> str: + """Indicate if the Triton models should be hosted in decoupled mode for streaming.""" + if self._model.format == ModelFormats.NEMO: + return "false" + return "true" if not self._http else "false" + + @property + def _allow_http(self) -> str: + """Indicate if Triton should allow http connections.""" + return "true" if self._http else "false" + + @property + def _allow_grpc(self) -> str: + """Inidicate if Triton should allow grpc connections.""" + return "true" if not self._http else "false" + + @property + def _tokenizer_model_dir(self) -> str: + """Inidicate where the tokenizer model can be found.""" + if self._model.format == ModelFormats.NEMO: + return self._model.engine_dir + return self._model.model_dir + + @property + def _gpt_model_type(self) -> str: + """Indicate the TRT LLM Backend mode.""" + if self._model.format == ModelFormats.NEMO: + return "V1" + return "inflight_fused_batching" + + @property + def model_repository(self) -> str: + """Return the triton model repository.""" + return os.path.join(_ENSEMBLE_MODEL_DIR, self._model.family) + + def _triton_server_cmd(self, rank: int) -> typing.List[str]: + """Generate the command to start a single triton server of given rank.""" + return [ + "-n", + "1", + _TRITON_BIN, + "--allow-http", + self._allow_http, + "--allow-grpc", + self._allow_grpc, + "--model-repository", + self.model_repository, + "--disable-auto-complete-config", + f"--backend-config=python,shm-region-prefix-name=prefix{rank}_", + ":", + ] + + @property + def _cmd(self) -> typing.List[str]: + """Generate the full command.""" + cmd = [_MPIRUN_BIN] + for rank in range(self._model.world_size): + cmd += self._triton_server_cmd(rank) + return cmd + + @property + def _env(self) -> typing.Dict[str, str]: + """Return the environment variable for the triton inference server.""" + env = dict(os.environ) + env["TRT_ENGINE_DIR"] = self._model.engine_dir + env["TOKENIZER_DIR"] = self._tokenizer_model_dir + if os.getuid() == 0: + _LOGGER.warning( + "Triton server will be running as root. It is recommended that you don't run this container as root." 
+ ) + env["OMPI_ALLOW_RUN_AS_ROOT"] = "1" + env["OMPI_ALLOW_RUN_AS_ROOT_CONFIRM"] = "1" + return env + + def _render_model_templates(self) -> None: + """Render and Jinja templates in the model directory.""" + env = Environment( + loader=FileSystemLoader(searchpath=self.model_repository), + autoescape=False, + ) # nosec; all the provided values are from code, not the user + + template_path = os.path.join("tensorrt_llm", "config.pbtxt.j2") + output_path = os.path.join( + self.model_repository, "tensorrt_llm", "config.pbtxt" + ) + + template = env.get_template(template_path) + + with open(output_path, "w", encoding="UTF-8") as out: + template_args = { + "engine_dir": self._model.engine_dir, + "decoupled_mode": self._decoupled_mode, + "gpt_model_type": self._gpt_model_type, + } + out.write(template.render(**template_args)) + + def run(self) -> int: + """Start the triton inference server.""" + cmd = self._cmd + env = self._env + + _LOGGER.debug("Rendering the ensemble models.") + self._render_model_templates() + + _LOGGER.debug("Starting triton with the command: %s", " ".join(cmd)) + _LOGGER.debug("Starting triton with the env vars: %s", repr(env)) + with subprocess.Popen(cmd, env=env) as proc: + try: + retcode = proc.wait() + except KeyboardInterrupt: + proc.kill() + return 0 + return retcode diff --git a/RetrievalAugmentedGeneration/llm-inference-server/model_server_client/trt_llm.py b/RetrievalAugmentedGeneration/llm-inference-server/model_server_client/trt_llm.py new file mode 100644 index 00000000..7291db4c --- /dev/null +++ b/RetrievalAugmentedGeneration/llm-inference-server/model_server_client/trt_llm.py @@ -0,0 +1,544 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""A Langchain LLM component for connecting to Triton + TensorRT LLM backend.""" +# pylint: disable=too-many-lines +import abc +import json +import queue +import random +import time +from functools import partial +from typing import Any, Callable, Dict, List, Optional, Type, Union + +import google.protobuf.json_format +import numpy as np +import tritonclient.grpc as grpcclient +import tritonclient.http as httpclient +from tritonclient.grpc.service_pb2 import ModelInferResponse +from tritonclient.utils import np_to_triton_dtype + +try: + from langchain.callbacks.manager import CallbackManagerForLLMRun + from langchain.llms.base import LLM + from langchain.pydantic_v1 import Field, root_validator + + USE_LANGCHAIN = True +except ImportError: + USE_LANGCHAIN = False + + +STOP_WORDS = [""] +RANDOM_SEED = 0 + +if USE_LANGCHAIN: + # pylint: disable-next=too-few-public-methods # Interface is defined by LangChain + class TensorRTLLM(LLM): # type: ignore # LLM class not typed in langchain + """A custom Langchain LLM class that integrates with TRTLLM triton models. + + Arguments: + server_url: (str) The URL of the Triton inference server to use. 
+            model_name: (str) The name of the Triton TRT model to use.
+            temperature: (float) The temperature to use for sampling
+            top_p: (float) The top-p value to use for sampling
+            top_k: (int) The top-k value to use for sampling
+            beam_width: (int) The number of beams to use for beam search
+            repetition_penalty: (float) The penalty to apply to repeated tokens
+            length_penalty: (float) The penalty to apply to the length of the output
+            tokens: (int) The maximum number of tokens to generate.
+            client: The client object used to communicate with the inference server
+        """
+
+        server_url: str = Field(None, alias="server_url")
+
+        # all the optional arguments
+        model_name: str = "ensemble"
+        temperature: Optional[float] = 1.0
+        top_p: Optional[float] = 0
+        top_k: Optional[int] = 1
+        tokens: Optional[int] = 100
+        beam_width: Optional[int] = 1
+        repetition_penalty: Optional[float] = 1.0
+        length_penalty: Optional[float] = 1.0
+        client: Any
+        streaming: Optional[bool] = True
+
+        @root_validator()  # type: ignore # typing not declared in langchain
+        @classmethod
+        def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+            """Validate that the python package exists in the environment."""
+            try:
+                if values.get("streaming", True):
+                    values["client"] = GrpcTritonClient(values["server_url"])
+                else:
+                    values["client"] = HttpTritonClient(values["server_url"])
+
+            except ImportError as err:
+                raise ImportError(
+                    "Could not import triton client python package. "
+                    "Please install it with `pip install tritonclient[all]`."
+                ) from err
+            return values
+
+        @property
+        def _get_model_default_parameters(self) -> Dict[str, Any]:
+            return {
+                "tokens": self.tokens,
+                "top_k": self.top_k,
+                "top_p": self.top_p,
+                "temperature": self.temperature,
+                "repetition_penalty": self.repetition_penalty,
+                "length_penalty": self.length_penalty,
+                "beam_width": self.beam_width,
+            }
+
+        @property
+        def _invocation_params(self, **kwargs: Any) -> Dict[str, Any]:
+            params = {**self._get_model_default_parameters, **kwargs}
+            return params
+
+        @property
+        def _identifying_params(self) -> Dict[str, Any]:
+            """Get all the identifying parameters."""
+            return {
+                "server_url": self.server_url,
+                "model_name": self.model_name,
+            }
+
+        @property
+        def _llm_type(self) -> str:
+            return "triton_tensorrt"
+
+        def _call(
+            self,
+            prompt: str,
+            stop: Optional[List[str]] = None,  # pylint: disable=unused-argument
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> str:
+            """
+            Execute an inference request.
+
+            Args:
+                prompt: The prompt to pass into the model.
+ stop: A list of strings to stop generation when encountered + + Returns: + The string generated by the model + """ + text_callback = None + if run_manager: + text_callback = partial( + run_manager.on_llm_new_token, verbose=self.verbose + ) + + invocation_params = self._get_model_default_parameters + invocation_params.update(kwargs) + invocation_params["prompt"] = [[prompt]] + model_params = self._identifying_params + model_params.update(kwargs) + request_id = str(random.randint(1, 9999999)) # nosec + + self.client.load_model(model_params["model_name"]) + if isinstance(self.client, GrpcTritonClient): + return self._streaming_request( + model_params, request_id, invocation_params, text_callback + ) + return self._request(model_params, invocation_params, text_callback) + + def _streaming_request( + self, + model_params: Dict[str, Any], + request_id: str, + invocation_params: Dict[str, Any], + text_callback: Optional[Callable[[str], None]], + ) -> str: + """Request a streaming inference session.""" + result_queue = self.client.request_streaming( + model_params["model_name"], request_id, **invocation_params + ) + + response = "" + for token in result_queue: + if text_callback: + text_callback(token) + response = response + token + return response + + def _request( + self, + model_params: Dict[str, Any], + invocation_params: Dict[str, Any], + text_callback: Optional[Callable[[str], None]], + ) -> str: + """Request a streaming inference session.""" + token: str = self.client.request( + model_params["model_name"], **invocation_params + ) + if text_callback: + text_callback(token) + return token + + +class StreamingResponseGenerator(queue.Queue[Optional[str]]): + """A Generator that provides the inference results from an LLM.""" + + def __init__( + self, client: "GrpcTritonClient", request_id: str, force_batch: bool + ) -> None: + """Instantiate the generator class.""" + super().__init__() + self._client = client + self.request_id = request_id + self._batch = force_batch + + def __iter__(self) -> "StreamingResponseGenerator": + """Return self as a generator.""" + return self + + def __next__(self) -> str: + """Return the next retrieved token.""" + val = self.get() + if val is None or val in STOP_WORDS: + self._stop_stream() + raise StopIteration() + return val + + def _stop_stream(self) -> None: + """Drain and shutdown the Triton stream.""" + self._client.stop_stream( + "tensorrt_llm", self.request_id, signal=not self._batch + ) + + +class _BaseTritonClient(abc.ABC): + """An abstraction of the connection to a triton inference server.""" + + def __init__(self, server_url: str) -> None: + """Initialize the client.""" + self._server_url = server_url + self._client = self._inference_server_client(server_url) + + @property + @abc.abstractmethod + def _inference_server_client( + self, + ) -> Union[ + Type[grpcclient.InferenceServerClient], Type[httpclient.InferenceServerClient] + ]: + """Return the prefered InferenceServerClient class.""" + + @property + @abc.abstractmethod + def _infer_input( + self, + ) -> Union[Type[grpcclient.InferInput], Type[httpclient.InferInput]]: + """Return the preferred InferInput.""" + + @property + @abc.abstractmethod + def _infer_output( + self, + ) -> Union[ + Type[grpcclient.InferRequestedOutput], Type[httpclient.InferRequestedOutput] + ]: + """Return the preferred InferRequestedOutput.""" + + def load_model(self, model_name: str, timeout: int = 1000) -> None: + """Load a model into the server.""" + if self._client.is_model_ready(model_name): + return + + 
self._client.load_model(model_name) + t0 = time.perf_counter() + t1 = t0 + while not self._client.is_model_ready(model_name) and t1 - t0 < timeout: + t1 = time.perf_counter() + + if not self._client.is_model_ready(model_name): + raise RuntimeError(f"Failed to load {model_name} on Triton in {timeout}s") + + def get_model_list(self) -> List[str]: + """Get a list of models loaded in the triton server.""" + res = self._client.get_model_repository_index(as_json=True) + return [model["name"] for model in res["models"]] + + def get_model_concurrency(self, model_name: str, timeout: int = 1000) -> int: + """Get the modle concurrency.""" + self.load_model(model_name, timeout) + instances = self._client.get_model_config(model_name, as_json=True)["config"][ + "instance_group" + ] + return sum(instance["count"] * len(instance["gpus"]) for instance in instances) + + def _generate_stop_signals( + self, + ) -> List[Union[grpcclient.InferInput, httpclient.InferInput]]: + """Generate the signal to stop the stream.""" + inputs = [ + self._infer_input("input_ids", [1, 1], "INT32"), + self._infer_input("input_lengths", [1, 1], "INT32"), + self._infer_input("request_output_len", [1, 1], "UINT32"), + self._infer_input("stop", [1, 1], "BOOL"), + ] + inputs[0].set_data_from_numpy(np.empty([1, 1], dtype=np.int32)) + inputs[1].set_data_from_numpy(np.zeros([1, 1], dtype=np.int32)) + inputs[2].set_data_from_numpy(np.array([[0]], dtype=np.uint32)) + inputs[3].set_data_from_numpy(np.array([[True]], dtype="bool")) + return inputs + + def _generate_outputs( + self, + ) -> List[Union[grpcclient.InferRequestedOutput, httpclient.InferRequestedOutput]]: + """Generate the expected output structure.""" + return [self._infer_output("text_output")] + + def _prepare_tensor( + self, name: str, input_data: Any + ) -> Union[grpcclient.InferInput, httpclient.InferInput]: + """Prepare an input data structure.""" + t = self._infer_input( + name, input_data.shape, np_to_triton_dtype(input_data.dtype) + ) + t.set_data_from_numpy(input_data) + return t + + def _generate_inputs( # pylint: disable=too-many-arguments,too-many-locals + self, + prompt: str, + tokens: int = 300, + temperature: float = 1.0, + top_k: float = 1, + top_p: float = 0, + beam_width: int = 1, + repetition_penalty: float = 1, + length_penalty: float = 1.0, + stream: bool = True, + ) -> List[Union[grpcclient.InferInput, httpclient.InferInput]]: + """Create the input for the triton inference server.""" + query = np.array(prompt).astype(object) + request_output_len = np.array([tokens]).astype(np.uint32).reshape((1, -1)) + runtime_top_k = np.array([top_k]).astype(np.uint32).reshape((1, -1)) + runtime_top_p = np.array([top_p]).astype(np.float32).reshape((1, -1)) + temperature_array = np.array([temperature]).astype(np.float32).reshape((1, -1)) + len_penalty = np.array([length_penalty]).astype(np.float32).reshape((1, -1)) + repetition_penalty_array = ( + np.array([repetition_penalty]).astype(np.float32).reshape((1, -1)) + ) + random_seed = np.array([RANDOM_SEED]).astype(np.uint64).reshape((1, -1)) + beam_width_array = np.array([beam_width]).astype(np.uint32).reshape((1, -1)) + streaming_data = np.array([[stream]], dtype=bool) + + inputs = [ + self._prepare_tensor("text_input", query), + self._prepare_tensor("max_tokens", request_output_len), + self._prepare_tensor("top_k", runtime_top_k), + self._prepare_tensor("top_p", runtime_top_p), + self._prepare_tensor("temperature", temperature_array), + self._prepare_tensor("length_penalty", len_penalty), + 
self._prepare_tensor("repetition_penalty", repetition_penalty_array), + self._prepare_tensor("random_seed", random_seed), + self._prepare_tensor("beam_width", beam_width_array), + self._prepare_tensor("stream", streaming_data), + ] + return inputs + + def _trim_batch_response(self, result_str: str) -> str: + """Trim the resulting response from a batch request by removing provided prompt and extra generated text.""" + # extract the generated part of the prompt + split = result_str.split("[/INST]", 1) + generated = split[-1] + end_token = generated.find("") + if end_token == -1: + return generated + generated = generated[:end_token].strip() + return generated + + +class GrpcTritonClient(_BaseTritonClient): + """GRPC connection to a triton inference server.""" + + @property + def _inference_server_client( + self, + ) -> Type[grpcclient.InferenceServerClient]: + """Return the prefered InferenceServerClient class.""" + return grpcclient.InferenceServerClient # type: ignore + + @property + def _infer_input(self) -> Type[grpcclient.InferInput]: + """Return the preferred InferInput.""" + return grpcclient.InferInput # type: ignore + + @property + def _infer_output( + self, + ) -> Type[grpcclient.InferRequestedOutput]: + """Return the preferred InferRequestedOutput.""" + return grpcclient.InferRequestedOutput # type: ignore + + def _send_stop_signals(self, model_name: str, request_id: str) -> None: + """Send the stop signal to the Triton Inference server.""" + stop_inputs = self._generate_stop_signals() + self._client.async_stream_infer( + model_name, + stop_inputs, + request_id=request_id, + parameters={"Streaming": True}, + ) + + @staticmethod + def _process_result(result: Dict[str, str]) -> str: + """Post-process the result from the server.""" + message = ModelInferResponse() + generated_text: str = "" + google.protobuf.json_format.Parse(json.dumps(result), message) + infer_result = grpcclient.InferResult(message) + np_res = infer_result.as_numpy("text_output") + + generated_text = "" + if np_res is not None: + generated_text = "".join([token.decode() for token in np_res]) + + return generated_text + + def _stream_callback( + self, + result_queue: queue.Queue[Union[Optional[Dict[str, str]], str]], + force_batch: bool, + result: Any, + error: str, + ) -> None: + """Add streamed result to queue.""" + if error: + result_queue.put(error) + else: + response_raw = result.get_response(as_json=True) + if "outputs" in response_raw: + # the very last response might have no output, just the final flag + response = self._process_result(response_raw) + if force_batch: + response = self._trim_batch_response(response) + + if response in STOP_WORDS: + result_queue.put(None) + else: + result_queue.put(response) + + if response_raw["parameters"]["triton_final_response"]["bool_param"]: + # end of the generation + result_queue.put(None) + + # pylint: disable-next=too-many-arguments + def _send_prompt_streaming( + self, + model_name: str, + request_inputs: Any, + request_outputs: Optional[Any], + request_id: str, + result_queue: StreamingResponseGenerator, + force_batch: bool = False, + ) -> None: + """Send the prompt and start streaming the result.""" + self._client.start_stream( + callback=partial(self._stream_callback, result_queue, force_batch) + ) + self._client.async_stream_infer( + model_name=model_name, + inputs=request_inputs, + outputs=request_outputs, + request_id=request_id, + ) + + def request_streaming( + self, + model_name: str, + request_id: Optional[str] = None, + force_batch: bool = False, + 
**params: Any, + ) -> StreamingResponseGenerator: + """Request a streaming connection.""" + if not self._client.is_model_ready(model_name): + raise RuntimeError("Cannot request streaming, model is not loaded") + + if not request_id: + request_id = str(random.randint(1, 9999999)) # nosec + + result_queue = StreamingResponseGenerator(self, request_id, force_batch) + inputs = self._generate_inputs(stream=not force_batch, **params) + outputs = self._generate_outputs() + self._send_prompt_streaming( + model_name, + inputs, + outputs, + request_id, + result_queue, + force_batch, + ) + return result_queue + + def stop_stream( + self, model_name: str, request_id: str, signal: bool = True + ) -> None: + """Close the streaming connection.""" + if signal: + self._send_stop_signals(model_name, request_id) + self._client.stop_stream() + + +class HttpTritonClient(_BaseTritonClient): + """HTTP connection to a triton inference server.""" + + @property + def _inference_server_client( + self, + ) -> Type[httpclient.InferenceServerClient]: + """Return the prefered InferenceServerClient class.""" + return httpclient.InferenceServerClient # type: ignore + + @property + def _infer_input(self) -> Type[httpclient.InferInput]: + """Return the preferred InferInput.""" + return httpclient.InferInput # type: ignore + + @property + def _infer_output( + self, + ) -> Type[httpclient.InferRequestedOutput]: + """Return the preferred InferRequestedOutput.""" + return httpclient.InferRequestedOutput # type: ignore + + def request( + self, + model_name: str, + **params: Any, + ) -> str: + """Request inferencing from the triton server.""" + if not self._client.is_model_ready(model_name): + raise RuntimeError("Cannot request streaming, model is not loaded") + + # create model inputs and outputs + inputs = self._generate_inputs(stream=False, **params) + outputs = self._generate_outputs() + + # call the model for inference + result = self._client.infer(model_name, inputs=inputs, outputs=outputs) + result_str = "".join( + [val.decode("utf-8") for val in result.as_numpy("text_output").tolist()] + ) + + # extract the generated part of the prompt + # return(result_str) + return self._trim_batch_response(result_str) diff --git a/RetrievalAugmentedGeneration/llm-inference-server/requirements.txt b/RetrievalAugmentedGeneration/llm-inference-server/requirements.txt new file mode 100644 index 00000000..3f6a7734 --- /dev/null +++ b/RetrievalAugmentedGeneration/llm-inference-server/requirements.txt @@ -0,0 +1,7 @@ +jinja2 +langchain +numpy +protobuf +requests +tritonclient[all] +pyyaml diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 00000000..35180962 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,24 @@ + ## Security + +NVIDIA is dedicated to the security and trust of our software products and services, including all source code repositories managed through our organization. + +If you need to report a security issue, please use the appropriate contact points outlined below. 
**Please do not report security vulnerabilities through GitHub.** + +## Reporting Potential Security Vulnerability in an NVIDIA Product + +To report a potential security vulnerability in any NVIDIA product: +- Web: [Security Vulnerability Submission Form](https://www.nvidia.com/object/submit-security-vulnerability.html) +- E-Mail: psirt@nvidia.com + - We encourage you to use the following PGP key for secure email communication: [NVIDIA public PGP Key for communication](https://www.nvidia.com/en-us/security/pgp-key) + - Please include the following information: + - Product/Driver name and version/branch that contains the vulnerability + - Type of vulnerability (code execution, denial of service, buffer overflow, etc.) + - Instructions to reproduce the vulnerability + - Proof-of-concept or exploit code + - Potential impact of the vulnerability, including how an attacker could exploit the vulnerability + +While NVIDIA currently does not have a bug bounty program, we do offer acknowledgement when an externally reported security issue is addressed under our coordinated vulnerability disclosure policy. Please visit our [Product Security Incident Response Team (PSIRT)](https://www.nvidia.com/en-us/security/psirt-policies/) policies page for more information. + +## NVIDIA Product Security + +For all security-related concerns, please visit NVIDIA's Product Security portal at https://www.nvidia.com/en-us/security \ No newline at end of file
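
Usage note: once the inference container is running and a model has been converted and hosted, the LangChain connector added in model_server_client/trt_llm.py can be used to send prompts to Triton. The sketch below is illustrative only; the gRPC address localhost:8001, the availability of the "ensemble" model, and the module being importable from the Python path are assumptions and may differ in your deployment.

    # minimal usage sketch, assuming langchain and tritonclient[all] are installed,
    # model_server_client is on the Python path, and the Triton gRPC endpoint is
    # reachable at localhost:8001 (assumed address)
    from model_server_client.trt_llm import TensorRTLLM

    llm = TensorRTLLM(server_url="localhost:8001", model_name="ensemble", tokens=100)
    print(llm("What does the llm-inference-server do?"))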