diff --git a/RetrievalAugmentedGeneration/LICENSE-Apache-2.0.txt b/RetrievalAugmentedGeneration/LICENSE-Apache-2.0.txt
new file mode 100644
index 00000000..08017da8
--- /dev/null
+++ b/RetrievalAugmentedGeneration/LICENSE-Apache-2.0.txt
@@ -0,0 +1,177 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+END OF TERMS AND CONDITIONS
diff --git a/RetrievalAugmentedGeneration/LICENSE.md b/RetrievalAugmentedGeneration/LICENSE.md
new file mode 100644
index 00000000..89bdbd99
--- /dev/null
+++ b/RetrievalAugmentedGeneration/LICENSE.md
@@ -0,0 +1,14 @@
+SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+SPDX-License-Identifier: Apache-2.0
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
\ No newline at end of file
diff --git a/RetrievalAugmentedGeneration/llm-inference-server/Dockerfile b/RetrievalAugmentedGeneration/llm-inference-server/Dockerfile
new file mode 100644
index 00000000..43fa8e79
--- /dev/null
+++ b/RetrievalAugmentedGeneration/llm-inference-server/Dockerfile
@@ -0,0 +1,22 @@
+ARG BASE_IMAGE_URL=nvcr.io/ea-bignlp/beta-inf-prerelease/infer
+ARG BASE_IMAGE_TAG=23.10.v3
+
+FROM ${BASE_IMAGE_URL}:${BASE_IMAGE_TAG}
+
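+# Make the TensorRT-LLM backend libraries visible to the dynamic loader at runtime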
+ENV LD_LIBRARY_PATH=/opt/tritonserver/backends/tensorrtllm:$LD_LIBRARY_PATH
+
+# install model-server automation
+COPY conversion_scripts /opt/conversion_scripts
+COPY ensemble_models /opt/ensemble_models
+COPY model_server /opt/model_server
+COPY model_server_client /opt/model_server_client
+RUN --mount=type=bind,source=requirements.txt,target=/opt/requirements.txt \
+ pip install --no-cache-dir -r /opt/requirements.txt
+
+# Create a world-writable (sticky-bit) /model directory and a private home directory
+# for the non-root triton-server user (uid/gid 1000)
+
+RUN mkdir /model && chmod 1777 /model && \
+ mkdir -p /home/triton-server && chown 1000:1000 /home/triton-server && chmod 700 /home/triton-server
+
+WORKDIR /opt
+ENTRYPOINT ["/usr/bin/python3", "-m", "model_server"]
diff --git a/RetrievalAugmentedGeneration/llm-inference-server/conversion_scripts/llama/build.py b/RetrievalAugmentedGeneration/llm-inference-server/conversion_scripts/llama/build.py
new file mode 100644
index 00000000..2a57d4cd
--- /dev/null
+++ b/RetrievalAugmentedGeneration/llm-inference-server/conversion_scripts/llama/build.py
@@ -0,0 +1,776 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import argparse
+import json
+import os
+import time
+from pathlib import Path
+
+import tensorrt as trt
+import tensorrt_llm
+import torch
+import torch.multiprocessing as mp
+from tensorrt_llm._utils import str_dtype_to_trt
+from tensorrt_llm.builder import Builder
+from tensorrt_llm.layers.attention import PositionEmbeddingType
+from tensorrt_llm.logger import logger
+from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.models import (
+ fp8_quantize,
+ smooth_quantize,
+ weight_only_groupwise_quantize,
+ weight_only_quantize,
+)
+from tensorrt_llm.network import net_guard
+from tensorrt_llm.plugin.plugin import ContextFMHAType
+from tensorrt_llm.quantization import QuantMode
+from transformers import LlamaConfig, LlamaForCausalLM
+from weight import (
+ get_scaling_factors,
+ load_from_awq_llama,
+ load_from_binary,
+ load_from_gptq_llama,
+ load_from_hf_llama,
+ load_from_meta_llama,
+)
+
+from weight import parse_ft_config # isort:skip
+
+MODEL_NAME = "llama"
+
+# Two routines, get_engine_name and serialize_engine, are direct copies from the
+# GPT example. TODO: move them into a shared utils module?
+
+import onnx
+from onnx import TensorProto, helper
+
+
+def trt_dtype_to_onnx(dtype):
+ if dtype == trt.float16:
+ return TensorProto.DataType.FLOAT16
+ elif dtype == trt.float32:
+ return TensorProto.DataType.FLOAT
+ elif dtype == trt.int32:
+ return TensorProto.DataType.INT32
+ else:
+ raise TypeError("%s is not supported" % dtype)
+
+
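+# Dump the TensorRT network topology (inputs, outputs and per-layer connectivity, but
+# no weights) as an ONNX graph so it can be inspected when --visualize is set.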
+def to_onnx(network, path):
+ inputs = []
+ for i in range(network.num_inputs):
+ network_input = network.get_input(i)
+ inputs.append(
+ helper.make_tensor_value_info(
+ network_input.name,
+ trt_dtype_to_onnx(network_input.dtype),
+ list(network_input.shape),
+ )
+ )
+
+ outputs = []
+ for i in range(network.num_outputs):
+ network_output = network.get_output(i)
+ outputs.append(
+ helper.make_tensor_value_info(
+ network_output.name,
+ trt_dtype_to_onnx(network_output.dtype),
+ list(network_output.shape),
+ )
+ )
+
+ nodes = []
+ for i in range(network.num_layers):
+ layer = network.get_layer(i)
+ layer_inputs = []
+ for j in range(layer.num_inputs):
+ ipt = layer.get_input(j)
+ if ipt is not None:
+ layer_inputs.append(layer.get_input(j).name)
+ layer_outputs = [layer.get_output(j).name for j in range(layer.num_outputs)]
+ nodes.append(
+ helper.make_node(
+ str(layer.type),
+ name=layer.name,
+ inputs=layer_inputs,
+ outputs=layer_outputs,
+ domain="com.nvidia",
+ )
+ )
+
+ onnx_model = helper.make_model(
+ helper.make_graph(nodes, "attention", inputs, outputs, initializer=None),
+ producer_name="NVIDIA",
+ )
+ onnx.save(onnx_model, path)
+
+
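+# Engine file naming, e.g. get_engine_name("llama", "float16", 2, 1, 0)
+# -> "llama_float16_tp2_rank0.engine"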
+def get_engine_name(model, dtype, tp_size, pp_size, rank):
+ if pp_size == 1:
+ return "{}_{}_tp{}_rank{}.engine".format(model, dtype, tp_size, rank)
+ return "{}_{}_tp{}_pp{}_rank{}.engine".format(model, dtype, tp_size, pp_size, rank)
+
+
+def serialize_engine(engine, path):
+ logger.info(f"Serializing engine to {path}...")
+ tik = time.time()
+ with open(path, "wb") as f:
+ f.write(bytearray(engine))
+ tok = time.time()
+ t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
+ logger.info(f"Engine serialized. Total time: {t}")
+
+
+def parse_arguments():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--world_size", type=int, default=1)
+ parser.add_argument("--tp_size", type=int, default=1)
+ parser.add_argument("--pp_size", type=int, default=1)
+ parser.add_argument("--model_dir", type=str, default=None)
+ parser.add_argument("--ft_model_dir", type=str, default=None)
+ parser.add_argument("--meta_ckpt_dir", type=str, default=None)
+ parser.add_argument("--quant_ckpt_path", type=str, default=None)
+ parser.add_argument(
+ "--dtype",
+ type=str,
+ default="float16",
+ choices=["float32", "bfloat16", "float16"],
+ )
+ parser.add_argument(
+ "--timing_cache",
+ type=str,
+ default="model.cache",
+ help="The path of to read timing cache from, will be ignored if the file does not exist",
+ )
+ parser.add_argument("--log_level", type=str, default="info")
+ parser.add_argument("--vocab_size", type=int, default=32000)
+ parser.add_argument("--n_layer", type=int, default=32)
+ parser.add_argument("--n_positions", type=int, default=2048)
+ parser.add_argument("--n_embd", type=int, default=4096)
+ parser.add_argument("--n_head", type=int, default=32)
+ parser.add_argument("--n_kv_head", type=int, default=None)
+ parser.add_argument("--multiple_of", type=int, default=256)
+ parser.add_argument("--ffn_dim_multiplier", type=float, default=1.0)
+ parser.add_argument("--inter_size", type=int, default=None)
+ parser.add_argument("--hidden_act", type=str, default="silu")
+ parser.add_argument("--rms_norm_eps", type=float, default=1e-06)
+ parser.add_argument("--max_batch_size", type=int, default=8)
+ parser.add_argument("--max_input_len", type=int, default=2048)
+ parser.add_argument("--max_output_len", type=int, default=512)
+ parser.add_argument("--max_beam_width", type=int, default=1)
+ parser.add_argument("--rotary_base", type=float, default=10000.0)
+ parser.add_argument("--rotary_scaling", nargs=2, type=str, default=None)
+ parser.add_argument(
+ "--use_gpt_attention_plugin",
+ nargs="?",
+ const="float16",
+ type=str,
+ default=False,
+ choices=["float16", "bfloat16", "float32"],
+ )
+ parser.add_argument(
+ "--use_gemm_plugin",
+ nargs="?",
+ const="float16",
+ type=str,
+ default=False,
+ choices=["float16", "bfloat16", "float32"],
+ )
+ parser.add_argument(
+ "--use_rmsnorm_plugin",
+ nargs="?",
+ const="float16",
+ type=str,
+ default=False,
+ choices=["float16", "float32", "bfloat16"],
+ )
+ parser.add_argument("--parallel_build", default=False, action="store_true")
+ parser.add_argument("--enable_context_fmha", default=False, action="store_true")
+ parser.add_argument(
+ "--enable_context_fmha_fp32_acc", default=False, action="store_true"
+ )
+ parser.add_argument("--visualize", default=False, action="store_true")
+ parser.add_argument("--enable_debug_output", default=False, action="store_true")
+ parser.add_argument("--gpus_per_node", type=int, default=8)
+ parser.add_argument("--builder_opt", type=int, default=None)
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="llama_outputs",
+ help="The path to save the serialized engine files, timing cache file and model configs",
+ )
+ parser.add_argument("--remove_input_padding", default=False, action="store_true")
+
+ # Arguments related to the quantization of the model.
+ parser.add_argument(
+ "--use_smooth_quant",
+ default=False,
+ action="store_true",
+ help="Use the SmoothQuant method to quantize activations and weights for the various GEMMs."
+ "See --per_channel and --per_token for finer-grained quantization options.",
+ )
+ parser.add_argument(
+ "--per_channel",
+ default=False,
+ action="store_true",
+ help="By default, we use a single static scaling factor for the GEMM's result. "
+ "per_channel instead uses a different static scaling factor for each channel. "
+ "The latter is usually more accurate, but a little slower.",
+ )
+ parser.add_argument(
+ "--per_token",
+ default=False,
+ action="store_true",
+ help="By default, we use a single static scaling factor to scale activations in the int8 range. "
+ "per_token chooses at run time, and for each token, a custom scaling factor. "
+ "The latter is usually more accurate, but a little slower.",
+ )
+ parser.add_argument(
+ "--per_group",
+ default=False,
+ action="store_true",
+ help="By default, we use a single static scaling factor to scale weights in the int4 range. "
+ "per_group chooses at run time, and for each group, a custom scaling factor. "
+ "The flag is built for GPTQ/AWQ quantization.",
+ )
+ parser.add_argument(
+ "--group_size",
+ type=int,
+ default=128,
+ help="Group size used in GPTQ/AWQ quantization.",
+ )
+ parser.add_argument(
+ "--int8_kv_cache",
+ default=False,
+ action="store_true",
+ help="By default, we use dtype for KV cache. int8_kv_cache chooses int8 quantization for KV",
+ )
+ parser.add_argument(
+ "--use_parallel_embedding",
+ action="store_true",
+ default=False,
+ help="By default embedding parallelism is disabled. By setting this flag, embedding parallelism is enabled",
+ )
+ parser.add_argument(
+ "--embedding_sharding_dim",
+ type=int,
+ default=1, # Meta does TP on hidden dim
+ choices=[0, 1],
+ help="By default the embedding lookup table is sharded along vocab dimension (embedding_sharding_dim=0). "
+ "To shard it along hidden dimension, set embedding_sharding_dim=1"
+ "Note: embedding sharing is only enabled when embedding_sharding_dim = 0",
+ )
+ parser.add_argument(
+ "--enable_fp8",
+ default=False,
+ action="store_true",
+ help="Use FP8 Linear layer for Attention QKV/Dense and MLP.",
+ )
+ parser.add_argument(
+ "--fp8_kv_cache",
+ default=False,
+ action="store_true",
+ help="By default, we use dtype for KV cache. fp8_kv_cache chooses int8 quantization for KV",
+ )
+ parser.add_argument(
+ "--quantized_fp8_model_path",
+ type=str,
+ default=None,
+ help="Path of a quantized model checkpoint in .npz format",
+ )
+ parser.add_argument(
+ "--use_weight_only",
+ default=False,
+ action="store_true",
+ help="Quantize weights for the various GEMMs to INT4/INT8."
+ "See --weight_only_precision to set the precision",
+ )
+ parser.add_argument(
+ "--weight_only_precision",
+ const="int8",
+ type=str,
+ nargs="?",
+ default="int8",
+ choices=["int8", "int4", "int4_awq", "int4_gptq"],
+ help="Define the precision for the weights when using weight-only quantization."
+ "You must also use --use_weight_only for that argument to have an impact.",
+ )
+ parser.add_argument(
+ "--use_inflight_batching",
+ action="store_true",
+ default=False,
+ help="Activates inflight batching mode of gptAttentionPlugin.",
+ )
+ parser.add_argument(
+ "--paged_kv_cache",
+ action="store_true",
+ default=False,
+ help="By default we use contiguous KV cache. By setting this flag you enable paged KV cache",
+ )
+ parser.add_argument(
+ "--tokens_per_block",
+ type=int,
+ default=64,
+ help="Number of tokens per block in paged KV cache",
+ )
+ parser.add_argument(
+ "--max_num_tokens",
+ type=int,
+ default=None,
+ help="Define the max number of tokens supported by the engine",
+ )
+ parser.add_argument(
+ "--strongly_typed",
+ default=False,
+ action="store_true",
+ help="This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.",
+ )
+ parser.add_argument(
+ "--use_custom_all_reduce",
+ action="store_true",
+ help="Activates latency-optimized algorithm for all-reduce instead of NCCL.",
+ )
+
+ args = parser.parse_args()
+ tensorrt_llm.logger.set_level(args.log_level)
+
+ assert not (
+ args.use_smooth_quant and args.use_weight_only
+ ), "You cannot enable both SmoothQuant and INT8 weight-only together."
+
+ if not args.remove_input_padding:
+ if args.use_gpt_attention_plugin:
+ logger.warning(
+ f"It is recommended to specify --remove_input_padding when using GPT attention plugin"
+ )
+
+ if args.use_inflight_batching:
+ if not args.use_gpt_attention_plugin:
+ args.use_gpt_attention_plugin = "float16"
+ logger.info(
+ f"Using GPT attention plugin for inflight batching mode. Setting to default '{args.use_gpt_attention_plugin}'"
+ )
+ if not args.remove_input_padding:
+ args.remove_input_padding = True
+ logger.info("Using remove input padding for inflight batching mode.")
+ if not args.paged_kv_cache:
+ args.paged_kv_cache = True
+ logger.info("Using paged KV cache for inflight batching mode.")
+
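+    # Fold the individual quantization flags into a single QuantMode bitmask that the
+    # builder config and weight loaders consume below.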
+ if args.use_smooth_quant:
+ args.quant_mode = QuantMode.use_smooth_quant(args.per_token, args.per_channel)
+ elif args.use_weight_only:
+ if args.per_group:
+ args.quant_mode = QuantMode.from_description(
+ quantize_weights=True,
+ quantize_activations=False,
+ per_token=False,
+ per_channel=False,
+ per_group=True,
+ use_int4_weights=True,
+ )
+ else:
+ args.quant_mode = QuantMode.use_weight_only(
+ args.weight_only_precision == "int4"
+ )
+ else:
+ args.quant_mode = QuantMode(0)
+
+ if args.int8_kv_cache:
+ args.quant_mode = args.quant_mode.set_int8_kv_cache()
+ elif args.fp8_kv_cache:
+ args.quant_mode = args.quant_mode.set_fp8_kv_cache()
+ if args.enable_fp8:
+ args.quant_mode = args.quant_mode.set_fp8_qdq()
+
+ if args.rotary_scaling is not None:
+ rotary_scaling = {
+ "type": args.rotary_scaling[0],
+ "factor": float(args.rotary_scaling[1]),
+ }
+ assert rotary_scaling["type"] in ["linear", "dynamic"]
+ assert rotary_scaling["factor"] > 1.0
+ args.rotary_scaling = rotary_scaling
+ if rotary_scaling["type"] == "dynamic":
+ assert not args.remove_input_padding, "TODO: Not supported yet"
+
+        # Since gpt_attention_plugin is the only way to apply RoPE now,
+ # force use the plugin for now with the correct data type.
+ args.use_gpt_attention_plugin = args.dtype
+ if args.model_dir is not None:
+ hf_config = LlamaConfig.from_pretrained(args.model_dir)
+ args.inter_size = (
+ hf_config.intermediate_size
+ ) # override the inter_size for LLaMA
+ args.n_embd = hf_config.hidden_size
+ args.n_head = hf_config.num_attention_heads
+ if hasattr(hf_config, "num_key_value_heads"):
+ args.n_kv_head = hf_config.num_key_value_heads
+ args.n_layer = hf_config.num_hidden_layers
+ args.n_positions = hf_config.max_position_embeddings
+ args.vocab_size = hf_config.vocab_size
+ args.hidden_act = hf_config.hidden_act
+ args.rms_norm_eps = hf_config.rms_norm_eps
+ elif args.meta_ckpt_dir is not None:
+ with open(Path(args.meta_ckpt_dir, "params.json")) as fp:
+ meta_config: dict = json.load(fp)
+ args.n_embd = meta_config["dim"]
+ args.n_head = meta_config["n_heads"]
+ args.n_layer = meta_config["n_layers"]
+ args.n_kv_head = meta_config.get("n_kv_heads", args.n_head)
+ args.multiple_of = meta_config["multiple_of"]
+ args.ffn_dim_multiplier = meta_config.get("ffn_dim_multiplier", 1)
+ n_embd = int(4 * args.n_embd * 2 / 3)
+ args.inter_size = args.multiple_of * (
+ (int(n_embd * args.ffn_dim_multiplier) + args.multiple_of - 1)
+ // args.multiple_of
+ )
+ args.rms_norm_eps = meta_config["norm_eps"]
+ elif args.ft_model_dir is not None:
+ (
+ n_embd,
+ n_head,
+ n_layer,
+ n_positions,
+ vocab_size,
+ hidden_act,
+ inter_size,
+ n_kv_head,
+ ) = parse_ft_config(Path(args.ft_model_dir) / "config.ini")
+ args.inter_size = inter_size # override the inter_size for LLaMA
+ args.n_kv_head = n_kv_head
+ args.n_embd = n_embd
+ args.n_head = n_head
+ args.n_layer = n_layer
+ args.n_positions = n_positions
+ args.vocab_size = vocab_size
+ args.hidden_act = hidden_act
+ args.rms_norm_eps = 1e-06
+ logger.warning("Set rms_norm_eps to 1e-06 directly.")
+    assert args.use_gpt_attention_plugin, "LLaMA must use the GPT attention plugin"
+ if args.n_kv_head is None:
+ args.n_kv_head = args.n_head
+ elif args.n_kv_head != args.n_head:
+ assert (
+ args.n_head % args.n_kv_head
+ ) == 0, "MQA/GQA requires the number of heads to be divisible by the number of K/V heads."
+ assert (args.n_kv_head % args.tp_size) == 0 or (
+ args.tp_size % args.n_kv_head
+ ) == 0, (
+ "MQA/GQA requires either the number of K/V heads to be divisible by the tensor parallelism size OR "
+ "the tensor parallelism size to be divisible by the number of K/V heads."
+ )
+
+ if args.dtype == "bfloat16":
+ assert args.use_gemm_plugin, "Please use gemm plugin when dtype is bfloat16"
+
+ assert args.pp_size * args.tp_size == args.world_size
+
+ if args.max_num_tokens is not None:
+ assert args.enable_context_fmha
+
+ if args.inter_size is None:
+        # this should not be needed when loading a real model
+ # but it is helpful when creating a dummy model without loading any real weights
+ n_embd = int(4 * args.n_embd * 2 / 3)
+ args.inter_size = args.multiple_of * (
+ (int(n_embd * args.ffn_dim_multiplier) + args.multiple_of - 1)
+ // args.multiple_of
+ )
+ logger.info(f"Setting inter_size to {args.inter_size}.")
+
+ return args
+
+
+def build_rank_engine(
+ builder: Builder,
+ builder_config: tensorrt_llm.builder.BuilderConfig,
+ engine_name,
+ rank,
+ args,
+):
+ """
+ @brief: Build the engine on the given rank.
+ @param rank: The rank to build the engine.
+ @param args: The cmd line arguments.
+ @return: The built engine.
+ """
+ dtype = str_dtype_to_trt(args.dtype)
+ mapping = Mapping(
+ world_size=args.world_size,
+ rank=rank,
+ tp_size=args.tp_size,
+ pp_size=args.pp_size,
+ )
+
+ assert (
+ args.n_layer % args.pp_size == 0
+ ), f"num_layers {args.n_layer} must be a multiple of pipeline parallelism size {args.pp_size}"
+
+ # Initialize Module
+ tensorrt_llm_llama = tensorrt_llm.models.LLaMAForCausalLM(
+ num_layers=args.n_layer,
+ num_heads=args.n_head,
+ num_kv_heads=args.n_kv_head,
+ hidden_size=args.n_embd,
+ vocab_size=args.vocab_size,
+ hidden_act=args.hidden_act,
+ max_position_embeddings=args.n_positions,
+ dtype=dtype,
+ mlp_hidden_size=args.inter_size,
+ position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
+ mapping=mapping,
+ rotary_base=args.rotary_base,
+ rotary_scaling=args.rotary_scaling,
+ use_parallel_embedding=args.use_parallel_embedding,
+ embedding_sharding_dim=args.embedding_sharding_dim,
+ quant_mode=args.quant_mode,
+ rms_norm_eps=args.rms_norm_eps,
+ )
+ if args.use_smooth_quant:
+ tensorrt_llm_llama = smooth_quantize(tensorrt_llm_llama, args.quant_mode)
+ elif args.use_weight_only:
+ if args.weight_only_precision == "int8":
+ tensorrt_llm_llama = weight_only_quantize(
+ tensorrt_llm_llama, args.quant_mode
+ )
+ elif args.weight_only_precision == "int4":
+ tensorrt_llm_llama = weight_only_quantize(
+ tensorrt_llm_llama, args.quant_mode
+ )
+ elif args.weight_only_precision == "int4_awq":
+ tensorrt_llm_llama = weight_only_groupwise_quantize(
+ model=tensorrt_llm_llama,
+ quant_mode=args.quant_mode,
+ group_size=args.group_size,
+ zero=False,
+ pre_quant_scale=True,
+ exclude_modules=[],
+ )
+ elif args.weight_only_precision == "int4_gptq":
+ tensorrt_llm_llama = weight_only_groupwise_quantize(
+ model=tensorrt_llm_llama,
+ quant_mode=args.quant_mode,
+ group_size=args.group_size,
+ zero=True,
+ pre_quant_scale=False,
+ )
+ elif args.enable_fp8 or args.fp8_kv_cache:
+ logger.info(f"Loading scaling factors from " f"{args.quantized_fp8_model_path}")
+ quant_scales = get_scaling_factors(
+ args.quantized_fp8_model_path,
+ num_layers=args.n_layer,
+ quant_mode=args.quant_mode,
+ )
+ tensorrt_llm_llama = fp8_quantize(
+ tensorrt_llm_llama, quant_mode=args.quant_mode, quant_scales=quant_scales
+ )
+ if args.per_group:
+ load_func = (
+ load_from_awq_llama
+ if args.weight_only_precision == "int4_awq"
+ else load_from_gptq_llama
+ )
+ load_func(
+ tensorrt_llm_llama=tensorrt_llm_llama,
+ quant_ckpt_path=args.quant_ckpt_path,
+ mapping=mapping,
+ dtype=args.dtype,
+ )
+ elif args.meta_ckpt_dir is not None:
+ load_from_meta_llama(
+ tensorrt_llm_llama, args.meta_ckpt_dir, mapping, args.dtype
+ )
+ elif args.model_dir is not None:
+ logger.info(f"Loading HF LLaMA ... from {args.model_dir}")
+ tik = time.time()
+ hf_llama = LlamaForCausalLM.from_pretrained(
+ args.model_dir,
+ device_map={"model": "cpu", "lm_head": "cpu"}, # Load to CPU memory
+ torch_dtype="auto",
+ )
+ tok = time.time()
+ t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
+ logger.info(f"HF LLaMA loaded. Total time: {t}")
+ load_from_hf_llama(
+ tensorrt_llm_llama, hf_llama, mapping=mapping, dtype=args.dtype
+ )
+ del hf_llama
+ elif args.ft_model_dir is not None:
+ load_from_binary(
+ tensorrt_llm_llama,
+ args.ft_model_dir,
+ mapping,
+ fp16=(args.dtype == "float16"),
+ multi_query_mode=(args.n_kv_head != args.n_head),
+ )
+
+ # Module -> Network
+ network = builder.create_network()
+ network.trt_network.name = engine_name
+ if args.use_gpt_attention_plugin:
+ network.plugin_config.set_gpt_attention_plugin(
+ dtype=args.use_gpt_attention_plugin
+ )
+ if args.use_gemm_plugin:
+ network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
+ if args.use_rmsnorm_plugin:
+ network.plugin_config.set_rmsnorm_plugin(dtype=args.use_rmsnorm_plugin)
+
+ # Quantization plugins.
+ if args.use_smooth_quant:
+ network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype)
+ network.plugin_config.set_rmsnorm_quantization_plugin(dtype=args.dtype)
+ network.plugin_config.set_quantize_tensor_plugin()
+ network.plugin_config.set_quantize_per_token_plugin()
+ assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc)
+ if args.enable_context_fmha:
+ network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
+ if args.enable_context_fmha_fp32_acc:
+ network.plugin_config.set_context_fmha(ContextFMHAType.enabled_with_fp32_acc)
+ if args.use_weight_only:
+ if args.per_group:
+ network.plugin_config.set_weight_only_groupwise_quant_matmul_plugin(
+ dtype="float16"
+ )
+ else:
+ network.plugin_config.set_weight_only_quant_matmul_plugin(dtype="float16")
+ if args.world_size > 1:
+ network.plugin_config.set_nccl_plugin(args.dtype, args.use_custom_all_reduce)
+ if args.remove_input_padding:
+ network.plugin_config.enable_remove_input_padding()
+ if args.paged_kv_cache:
+ network.plugin_config.enable_paged_kv_cache(args.tokens_per_block)
+
+ with net_guard(network):
+ # Prepare
+ network.set_named_parameters(tensorrt_llm_llama.named_parameters())
+
+ # Forward
+ inputs = tensorrt_llm_llama.prepare_inputs(
+ args.max_batch_size,
+ args.max_input_len,
+ args.max_output_len,
+ True,
+ args.max_beam_width,
+ args.max_num_tokens,
+ )
+ tensorrt_llm_llama(*inputs)
+ if args.enable_debug_output:
+ # mark intermediate nodes' outputs
+ for k, v in tensorrt_llm_llama.named_network_outputs():
+ v = v.trt_tensor
+ v.name = k
+ network.trt_network.mark_output(v)
+ v.dtype = dtype
+ if args.visualize:
+ model_path = os.path.join(args.output_dir, "test.onnx")
+ to_onnx(network.trt_network, model_path)
+
+ tensorrt_llm.graph_rewriting.optimize(network)
+
+ engine = None
+
+ # Network -> Engine
+ engine = builder.build_engine(network, builder_config)
+ if rank == 0:
+ config_path = os.path.join(args.output_dir, "config.json")
+ builder.save_config(builder_config, config_path)
+ return engine
+
+
+def build(rank, args):
+ torch.cuda.set_device(rank % args.gpus_per_node)
+ logger.set_level(args.log_level)
+ if not os.path.exists(args.output_dir):
+ os.makedirs(args.output_dir)
+
+    # in a serial build, all ranks share one builder (and its timing cache)
+ builder = Builder()
+
+ cache = None
+ for cur_rank in range(args.world_size):
+ # skip other ranks if parallel_build is enabled
+ if args.parallel_build and cur_rank != rank:
+ continue
+        # NOTE: when only int8 KV cache is used together with paged KV cache, no int8 tensors are exposed to TRT
+ int8_trt_flag = args.quant_mode.has_act_and_weight_quant() or (
+ not args.paged_kv_cache and args.quant_mode.has_int8_kv_cache()
+ )
+ builder_config = builder.create_builder_config(
+ name=MODEL_NAME,
+ precision=args.dtype,
+ timing_cache=args.timing_cache if cache is None else cache,
+ tensor_parallel=args.tp_size,
+ pipeline_parallel=args.pp_size,
+ parallel_build=args.parallel_build,
+ num_layers=args.n_layer,
+ num_heads=args.n_head,
+ num_kv_heads=args.n_kv_head,
+ hidden_size=args.n_embd,
+ vocab_size=args.vocab_size,
+ hidden_act=args.hidden_act,
+ max_position_embeddings=args.n_positions,
+ max_batch_size=args.max_batch_size,
+ max_input_len=args.max_input_len,
+ max_output_len=args.max_output_len,
+ max_num_tokens=args.max_num_tokens,
+ int8=int8_trt_flag,
+ fp8=args.quant_mode.has_fp8_qdq(),
+ quant_mode=args.quant_mode,
+ strongly_typed=args.strongly_typed,
+ opt_level=args.builder_opt,
+ )
+ engine_name = get_engine_name(
+ MODEL_NAME, args.dtype, args.tp_size, args.pp_size, cur_rank
+ )
+ engine = build_rank_engine(builder, builder_config, engine_name, cur_rank, args)
+ assert engine is not None, f"Failed to build engine for rank {cur_rank}"
+
+ if cur_rank == 0:
+ # Use in-memory timing cache for multiple builder passes.
+ if not args.parallel_build:
+ cache = builder_config.trt_builder_config.get_timing_cache()
+
+ serialize_engine(engine, os.path.join(args.output_dir, engine_name))
+
+ if rank == 0:
+ ok = builder.save_timing_cache(
+ builder_config, os.path.join(args.output_dir, "model.cache")
+ )
+ assert ok, "Failed to save timing cache."
+
+
+if __name__ == "__main__":
+ args = parse_arguments()
+ tik = time.time()
+ if (
+ args.parallel_build
+ and args.world_size > 1
+ and torch.cuda.device_count() >= args.world_size
+ ):
+ logger.warning(
+ f"Parallelly build TensorRT engines. Please make sure that all of the {args.world_size} GPUs are totally free."
+ )
+ mp.spawn(build, nprocs=args.world_size, args=(args,))
+ else:
+ args.parallel_build = False
+ logger.info("Serially build TensorRT engines.")
+ build(0, args)
+
+ tok = time.time()
+ t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
+ logger.info(f"Total time of building all {args.world_size} engines: {t}")
diff --git a/RetrievalAugmentedGeneration/llm-inference-server/conversion_scripts/llama/weight.py b/RetrievalAugmentedGeneration/llm-inference-server/conversion_scripts/llama/weight.py
new file mode 100644
index 00000000..692ae67f
--- /dev/null
+++ b/RetrievalAugmentedGeneration/llm-inference-server/conversion_scripts/llama/weight.py
@@ -0,0 +1,1446 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import configparser
+import time
+from operator import attrgetter
+from pathlib import Path
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+import tensorrt_llm
+import tensorrt_llm.logger as logger
+import torch
+from safetensors import safe_open
+from tensorrt_llm._utils import str_dtype_to_torch, torch_to_numpy
+from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.models import LLaMAForCausalLM
+from tensorrt_llm.models.quantized.quant import get_dummy_quant_scales
+from tensorrt_llm.quantization import QuantMode
+
+
+def get_scaling_factors(
+ model_path: Union[str, Path],
+ num_layers: int,
+ quant_mode: Optional[QuantMode] = None,
+) -> Optional[Dict[str, List[int]]]:
+ """Get the scaling factors for LLaMA model
+
+ Returns a dictionary of scaling factors for the selected layers of the
+ LLaMA model.
+
+ Args:
+        model_path (str): Path to the quantized LLaMA checkpoint (.npz). If None,
+            dummy quantization scales are returned.
+        num_layers (int): Number of transformer layers to collect scaling factors for.
+        quant_mode (QuantMode, optional): Controls whether KV-cache output scales
+            are included.
+
+ Returns:
+ dict: Dictionary of scaling factors for the selected layers of the
+ LLaMA model.
+
+ example:
+
+ {
+ 'qkv_act': qkv_act_scale,
+ 'qkv_weights': qkv_weights_scale,
+ 'qkv_output' : qkv_outputs_scale,
+ 'dense_act': dense_act_scale,
+ 'dense_weights': dense_weights_scale,
+ 'fc_act': fc_act_scale,
+ 'fc_weights': fc_weights_scale,
+ 'gate_act': gate_act_scale,
+ 'gate_weights': gate_weights_scale,
+ 'proj_act': proj_act_scale,
+ 'proj_weights': proj_weights_scale,
+ }
+ """
+
+ if model_path is None:
+ logger.warning(
+ f"--quantized_fp8_model_path not specified. "
+ f"Initialize quantization scales automatically."
+ )
+ return get_dummy_quant_scales(num_layers)
+ weight_dict = np.load(model_path)
+
+ # yapf: disable
+ scaling_factor = {
+ 'qkv_act': [],
+ 'qkv_weights': [],
+ 'qkv_output': [],
+ 'dense_act': [],
+ 'dense_weights': [],
+ 'fc_act': [],
+ 'fc_weights': [],
+ 'gate_act': [],
+ 'gate_weights': [],
+ 'proj_act': [],
+ 'proj_weights': [],
+ }
+
+ for layer in range(num_layers):
+ scaling_factor['qkv_act'].append(max(
+ weight_dict[f'_np:layers:{layer}:attention:qkv:q:activation_scaling_factor'].item(),
+ weight_dict[f'_np:layers:{layer}:attention:qkv:k:activation_scaling_factor'].item(),
+ weight_dict[f'_np:layers:{layer}:attention:qkv:v:activation_scaling_factor'].item()
+ ))
+ scaling_factor['qkv_weights'].append(max(
+ weight_dict[f'_np:layers:{layer}:attention:qkv:q:weights_scaling_factor'].item(),
+ weight_dict[f'_np:layers:{layer}:attention:qkv:k:weights_scaling_factor'].item(),
+ weight_dict[f'_np:layers:{layer}:attention:qkv:v:weights_scaling_factor'].item()
+ ))
+ if quant_mode is not None and quant_mode.has_fp8_kv_cache():
+            # Not calibrating the KV cache.
+ scaling_factor['qkv_output'].append(1.0)
+ scaling_factor['dense_act'].append(weight_dict[f'_np:layers:{layer}:attention:dense:activation_scaling_factor'].item())
+ scaling_factor['dense_weights'].append(weight_dict[f'_np:layers:{layer}:attention:dense:weights_scaling_factor'].item())
+ scaling_factor['fc_act'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:activation_scaling_factor'].item())
+ scaling_factor['fc_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:fc:weights_scaling_factor'].item())
+ scaling_factor['gate_act'].append(weight_dict[f'_np:layers:{layer}:mlp:gate:activation_scaling_factor'].item())
+ scaling_factor['gate_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:gate:weights_scaling_factor'].item())
+ scaling_factor['proj_act'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:activation_scaling_factor'].item())
+ scaling_factor['proj_weights'].append(weight_dict[f'_np:layers:{layer}:mlp:proj:weights_scaling_factor'].item())
+ # yapf: enable
+ for k, v in scaling_factor.items():
+ assert (
+ len(v) == num_layers
+ ), f"Expect scaling factor {k} of length {num_layers}, got {len(v)}"
+
+ return scaling_factor
+
+
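+# Build the per-rank weight-file suffix, e.g. gen_suffix(0, True, True) -> "int8.col.0.bin"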
+def gen_suffix(rank, use_smooth_quant, quant_per_channel):
+ suffix = f"{rank}.bin"
+ if use_smooth_quant:
+ sq_prefix = "int8."
+ if quant_per_channel:
+ sq_prefix += "col."
+ suffix = sq_prefix + suffix
+ return suffix
+
+
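+# Pull the numeric layer index out of a parameter name,
+# e.g. "model.layers.15.self_attn.q_proj.weight" -> "15"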
+def extract_layer_idx(name):
+ ss = name.split(".")
+ for s in ss:
+ if s.isdigit():
+ return s
+ return None
+
+
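+# Return the idx-th contiguous shard of v along dim for tensor parallelism; e.g. a
+# (4096, 11008) weight with tp_size=2 and dim=1 yields a (4096, 5504) slice.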
+def split(v, tp_size, idx, dim=0):
+ if tp_size == 1:
+ return v
+ if len(v.shape) == 1:
+ return np.ascontiguousarray(np.split(v, tp_size)[idx])
+ else:
+ return np.ascontiguousarray(np.split(v, tp_size, axis=dim)[idx])
+
+
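+# Replicate each KV head so that checkpoints with fewer KV heads than TP ranks (MQA/GQA)
+# can still be split evenly, e.g. 2 KV heads with tp_size=8 duplicates every head 4 times.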
+def dup_kv_weight(v, num_head, tp_size):
+ assert tp_size % num_head == 0
+ reps = tp_size // num_head
+ head_size = v.shape[0] // num_head
+ v = v.reshape(num_head, head_size, -1)[:, None, :, :].expand(
+ num_head, reps, head_size, v.shape[1]
+ )
+ return v.reshape(num_head * reps * head_size, -1).clone()
+
+
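+# Read model hyper-parameters from the [llama] section of the FasterTransformer-style
+# config.ini that accompanies a binary (--ft_model_dir) checkpoint.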
+def parse_ft_config(ini_file):
+ gpt_config = configparser.ConfigParser()
+ gpt_config.read(ini_file)
+
+ n_embd = gpt_config.getint("llama", "hidden_size")
+ n_head = gpt_config.getint("llama", "num_attention_heads")
+ n_layer = gpt_config.getint("llama", "num_hidden_layers")
+ n_positions = gpt_config.getint("llama", "max_position_embeddings")
+ vocab_size = gpt_config.getint("llama", "vocab_size")
+ hidden_act = gpt_config.get("llama", "hidden_act")
+ inter_size = gpt_config.getint("llama", "intermediate_size", fallback=None)
+ n_kv_head = gpt_config.getint("llama", "num_key_value_heads", fallback=None)
+
+ if inter_size is None:
+ inter_size = 4 * n_embd
+
+ return (
+ n_embd,
+ n_head,
+ n_layer,
+ n_positions,
+ vocab_size,
+ hidden_act,
+ inter_size,
+ n_kv_head,
+ )
+
+
+def load_from_hf_llama(
+ tensorrt_llm_llama: tensorrt_llm.models.LLaMAForCausalLM,
+ hf_llama,
+ mapping=Mapping(),
+ dtype="float32",
+):
+ tensorrt_llm.logger.info("Loading weights from HF LLaMA...")
+ tik = time.time()
+
+ quant_mode = getattr(tensorrt_llm_llama, "quant_mode", QuantMode(0))
+ if quant_mode.is_int8_weight_only():
+ plugin_weight_only_quant_type = torch.int8
+ elif quant_mode.is_int4_weight_only():
+ plugin_weight_only_quant_type = torch.quint4x2
+ use_weight_only = quant_mode.is_weight_only()
+ num_kv_heads = tensorrt_llm_llama.num_kv_heads
+ mha_mode = num_kv_heads == tensorrt_llm_llama.num_heads
+
+ model_params = dict(hf_llama.named_parameters())
+ for l in range(hf_llama.config.num_hidden_layers):
+ prefix = f"model.layers.{l}.self_attn."
+ q_weight = model_params[prefix + "q_proj.weight"]
+ k_weight = model_params[prefix + "k_proj.weight"]
+ v_weight = model_params[prefix + "v_proj.weight"]
+ if not mha_mode:
+ head_size = tensorrt_llm_llama.hidden_size // tensorrt_llm_llama.num_heads
+ if num_kv_heads < mapping.tp_size:
+ # duplicate the KV heads up to tensor_parallel
+ k_weight = dup_kv_weight(k_weight, num_kv_heads, mapping.tp_size)
+ v_weight = dup_kv_weight(v_weight, num_kv_heads, mapping.tp_size)
+ assert (k_weight.shape[0] % (mapping.tp_size * head_size)) == 0
+ assert (v_weight.shape[0] % (mapping.tp_size * head_size)) == 0
+ qkv_weight = [q_weight, k_weight, v_weight]
+ else:
+ qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0)
+
+ model_params[prefix + "qkv_proj.weight"] = qkv_weight
+
+ torch_dtype = str_dtype_to_torch(dtype)
+ layers_per_pipeline_stage = hf_llama.config.num_hidden_layers // mapping.pp_size
+ layers_range = list(
+ range(
+ mapping.pp_rank * layers_per_pipeline_stage,
+ (mapping.pp_rank + 1) * layers_per_pipeline_stage,
+ 1,
+ )
+ )
+ for k, v in model_params.items():
+ if isinstance(v, list):
+ v = [torch_to_numpy(vv.to(torch_dtype).detach().cpu()) for vv in v]
+ else:
+ v = torch_to_numpy(v.to(torch_dtype).detach().cpu())
+ if "model.embed_tokens.weight" in k:
+ if tensorrt_llm_llama.use_parallel_embedding:
+ v = split(
+ v,
+ mapping.tp_size,
+ mapping.tp_rank,
+ tensorrt_llm_llama.embedding_sharding_dim,
+ )
+ if mapping.is_first_pp_rank():
+ tensorrt_llm_llama.vocab_embedding.weight.value = v
+ elif "model.norm.weight" in k:
+ if mapping.is_last_pp_rank():
+ tensorrt_llm_llama.ln_f.weight.value = v
+ elif "lm_head.weight" in k:
+ if mapping.is_last_pp_rank():
+ tensorrt_llm_llama.lm_head.weight.value = np.ascontiguousarray(
+ split(v, mapping.tp_size, mapping.tp_rank)
+ )
+ else:
+ layer_idx = extract_layer_idx(k)
+ if layer_idx is None or int(layer_idx) not in layers_range:
+ continue
+ idx = int(layer_idx) - mapping.pp_rank * layers_per_pipeline_stage
+ if idx >= tensorrt_llm_llama.num_layers:
+ continue
+ if "input_layernorm.weight" in k:
+ tensorrt_llm_llama.layers[idx].input_layernorm.weight.value = v
+ elif "post_attention_layernorm.weight" in k:
+ dst = tensorrt_llm_llama.layers[idx].post_layernorm.weight
+ dst.value = v
+ elif "self_attn.qkv_proj.weight" in k:
+ dst = tensorrt_llm_llama.layers[idx].attention.qkv.weight
+ if not mha_mode:
+ assert isinstance(v, list) and len(v) == 3
+ wq = split(v[0], mapping.tp_size, mapping.tp_rank)
+ wk = split(v[1], mapping.tp_size, mapping.tp_rank)
+ wv = split(v[2], mapping.tp_size, mapping.tp_rank)
+ split_v = np.concatenate((wq, wk, wv))
+ else:
+ q_emb = v.shape[0] // 3
+ model_emb = v.shape[1]
+ v = v.reshape(3, q_emb, model_emb)
+ split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=1)
+ split_v = split_v.reshape(3 * (q_emb // mapping.tp_size), model_emb)
+ if use_weight_only:
+ v = np.ascontiguousarray(split_v.transpose())
+ (
+ processed_torch_weights,
+ torch_weight_scales,
+ ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
+ torch.tensor(v), plugin_weight_only_quant_type
+ )
+ # workaround for trt not supporting int8 inputs in plugins currently
+ dst.value = processed_torch_weights.view(
+ dtype=torch.float32
+ ).numpy()
+ scales = tensorrt_llm_llama.layers[
+ idx
+ ].attention.qkv.per_channel_scale
+ scales.value = torch_weight_scales.numpy()
+ else:
+ dst.value = np.ascontiguousarray(split_v)
+ elif "self_attn.o_proj.weight" in k:
+ dst = tensorrt_llm_llama.layers[idx].attention.dense.weight
+ split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=1)
+ if use_weight_only:
+ v = np.ascontiguousarray(split_v.transpose())
+ (
+ processed_torch_weights,
+ torch_weight_scales,
+ ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
+ torch.tensor(v), plugin_weight_only_quant_type
+ )
+ # workaround for trt not supporting int8 inputs in plugins currently
+ dst.value = processed_torch_weights.view(
+ dtype=torch.float32
+ ).numpy()
+ scales = tensorrt_llm_llama.layers[
+ idx
+ ].attention.dense.per_channel_scale
+ scales.value = torch_weight_scales.numpy()
+ else:
+ dst.value = np.ascontiguousarray(split_v)
+ elif "mlp.up_proj.weight" in k:
+ dst = tensorrt_llm_llama.layers[idx].mlp.gate.weight
+ split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=0)
+ if use_weight_only:
+ v = np.ascontiguousarray(split_v.transpose())
+ (
+ processed_torch_weights,
+ torch_weight_scales,
+ ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
+ torch.tensor(v), plugin_weight_only_quant_type
+ )
+ # workaround for trt not supporting int8 inputs in plugins currently
+ dst.value = processed_torch_weights.view(
+ dtype=torch.float32
+ ).numpy()
+ scales = tensorrt_llm_llama.layers[idx].mlp.gate.per_channel_scale
+ scales.value = torch_weight_scales.numpy()
+ else:
+ dst.value = np.ascontiguousarray(split_v)
+ elif "mlp.down_proj.weight" in k:
+ dst = tensorrt_llm_llama.layers[idx].mlp.proj.weight
+ split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=1)
+ if use_weight_only:
+ v = np.ascontiguousarray(split_v.transpose())
+ (
+ processed_torch_weights,
+ torch_weight_scales,
+ ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
+ torch.tensor(v), plugin_weight_only_quant_type
+ )
+ # workaround for trt not supporting int8 inputs in plugins currently
+ dst.value = processed_torch_weights.view(
+ dtype=torch.float32
+ ).numpy()
+ scales = tensorrt_llm_llama.layers[idx].mlp.proj.per_channel_scale
+ scales.value = torch_weight_scales.numpy()
+ else:
+ dst.value = np.ascontiguousarray(split_v)
+ elif "mlp.gate_proj.weight" in k:
+ dst = tensorrt_llm_llama.layers[idx].mlp.fc.weight
+ split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=0)
+ if use_weight_only:
+ v = np.ascontiguousarray(split_v.transpose())
+ (
+ processed_torch_weights,
+ torch_weight_scales,
+ ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
+ torch.tensor(v), plugin_weight_only_quant_type
+ )
+ # workaround for trt not supporting int8 inputs in plugins currently
+ dst.value = processed_torch_weights.view(
+ dtype=torch.float32
+ ).numpy()
+ scales = tensorrt_llm_llama.layers[idx].mlp.fc.per_channel_scale
+ scales.value = torch_weight_scales.numpy()
+ else:
+ dst.value = np.ascontiguousarray(split_v)
+
+ tok = time.time()
+ t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
+ tensorrt_llm.logger.info(f"Weights loaded. Total time: {t}")
+ return
+
+
+def load_from_meta_llama(
+ tensorrt_llm_llama: tensorrt_llm.models.LLaMAForCausalLM,
+ meta_ckpt_dir,
+ mapping=Mapping(),
+ dtype="float32",
+):
+ torch_dtype = str_dtype_to_torch(dtype)
+
+ def gather_ckpts(ckpts):
+ gathered = {}
+ for k in ckpts[0]:
+ d = 0
+ if any([n in k for n in ["wo", "w2", "tok"]]):
+ d = 1
+ if "norm" in k or "rope" in k: # no TP
+ gathered[k] = ckpts[0][k].clone()
+ else:
+ gathered[k] = torch.cat([pt[k] for pt in ckpts], dim=d).clone()
+ return gathered
+
+ def split_ckpt(ckpt, ranks_per_ckpt, ckpt_rank):
+ split_ckpt = {}
+ for k in ckpt:
+ d = 0
+ if any([n in k for n in ["wo", "w2", "tok"]]):
+ d = 1
+ if "norm" in k or "rope" in k: # no TP
+ split_ckpt[k] = ckpt[k].clone()
+ elif tensorrt_llm_llama.num_kv_heads < mapping.tp_size and any(
+ [n in k for n in ["wk", "wv"]]
+ ):
+ assert mapping.tp_size % tensorrt_llm_llama.num_kv_heads == 0
+ # special case: we need to duplicate KV head
+ tmp = dup_kv_weight(
+ ckpt[k], tensorrt_llm_llama.num_kv_heads, mapping.tp_size
+ )
+ split_ckpt[k] = torch.split(tmp, tmp.shape[d] // ranks_per_ckpt, dim=d)[
+ ckpt_rank
+ ].clone()
+ else:
+ split_ckpt[k] = torch.split(
+ ckpt[k], ckpt[k].shape[d] // ranks_per_ckpt, dim=d
+ )[ckpt_rank].clone()
+ return split_ckpt
+
+ def get_current_weights(num_ckpts):
+ if num_ckpts > mapping.tp_size:
+ # combine ckpts
+ assert (num_ckpts % mapping.tp_size) == 0
+ nf = num_ckpts // mapping.tp_size
+ fs = nf * mapping.tp_rank
+ file_ids = list(range(fs, fs + nf))
+ ckpts = []
+ for f in file_ids:
+ ckpt = torch.load(
+ Path(meta_ckpt_dir, f"consolidated.{f:02d}.pth"), map_location="cpu"
+ )
+ ckpts.append(ckpt)
+ return gather_ckpts(ckpts)
+ elif num_ckpts < mapping.tp_size:
+ # split ckpt
+ assert (mapping.tp_size % num_ckpts) == 0
+ ranks_per_ckpt = mapping.tp_size // num_ckpts
+ ckpt_fid = mapping.tp_rank // ranks_per_ckpt
+ ckpt_rank = mapping.tp_rank % ranks_per_ckpt
+ nH_per_ckpt = tensorrt_llm_llama.num_heads // num_ckpts
+ assert (nH_per_ckpt % ranks_per_ckpt) == 0
+ ckpt = torch.load(
+ Path(meta_ckpt_dir, f"consolidated.{ckpt_fid:02d}.pth"),
+ map_location="cpu",
+ )
+ return split_ckpt(ckpt, ranks_per_ckpt, ckpt_rank)
+
+ # num_ckpts == tensor_parallel, 1:1 mapping from files to TP
+ return torch.load(
+ Path(meta_ckpt_dir, f"consolidated.{mapping.tp_rank:02d}.pth"),
+ map_location="cpu",
+ )
+
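+    # Reorder each head's rotary dimensions from Meta's interleaved layout to the
+    # half-split layout expected by the rope_gpt_neox position embedding.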
+ def permute(w, nH, d, dH):
+ # due to MQA's wk, nH*dH != d could be true
+ return w.view(nH, dH // 2, 2, d).transpose(1, 2).reshape(nH * dH, d)
+
+ if not hasattr(load_from_meta_llama, "saved_embed"):
+ load_from_meta_llama.saved_embed = None
+
+ def gather_embedding(cur_embed, name: str, num_ckpts):
+ if mapping.tp_size == 1:
+ # even if num_ckpts > 1, get_current_weights will already have it gathered
+ return cur_embed
+ if load_from_meta_llama.saved_embed is None:
+ embeds = [None] * num_ckpts
+ for i in range(num_ckpts):
+ ckpt = torch.load(
+ Path(meta_ckpt_dir, f"consolidated.{i:02d}.pth"), map_location="cpu"
+ )
+ embeds[i] = ckpt[name]
+ embed = torch.cat(embeds, dim=1).to(torch_dtype)
+ load_from_meta_llama.saved_embed = torch_to_numpy(
+ embed
+ ) # cache the embedding, not needed if no refit
+ return load_from_meta_llama.saved_embed
+
+ tensorrt_llm.logger.info("Loading weights from Meta LLaMA checkpoints ...")
+ tik = time.time()
+
+ quant_mode = getattr(tensorrt_llm_llama, "quant_mode", QuantMode(0))
+    # Weight-only quantization types are not consulted when loading Meta checkpoints.
+ num_kv_heads = tensorrt_llm_llama.num_kv_heads
+ mha_mode = num_kv_heads == tensorrt_llm_llama.num_heads
+
+ ckpts = list(Path(meta_ckpt_dir).glob("consolidated.*.pth"))
+ num_ckpts = len(ckpts)
+ # llama/llama2 doesn't have MQA. So, simplifying loader logic by not worrying about it.
+ assert (
+ num_kv_heads > 1 or num_kv_heads >= num_ckpts
+ ), f"We don't know how the {num_kv_heads} KV heads are distributed among {num_ckpts} checkpoints."
+
+ head_size = tensorrt_llm_llama.hidden_size // tensorrt_llm_llama.num_heads
+ ckpt = get_current_weights(num_ckpts)
+ layers_range = list(
+ range(
+ mapping.pp_rank * tensorrt_llm_llama.num_layers,
+ (mapping.pp_rank + 1) * tensorrt_llm_llama.num_layers,
+ 1,
+ )
+ )
+
+ for l in layers_range:
+ prefix = f"layers.{l}.attention."
+ q_weight = permute(
+ ckpt[prefix + "wq.weight"].clone(),
+ nH=(tensorrt_llm_llama.num_heads // mapping.tp_size),
+ d=tensorrt_llm_llama.hidden_size,
+ dH=head_size,
+ )
+ if num_kv_heads < mapping.tp_size and num_ckpts >= mapping.tp_size:
+ assert mapping.tp_size % num_kv_heads == 0
+ assert False, "Not supported yet"
+ k_weight = permute(
+ ckpt[prefix + "wk.weight"].clone(),
+ nH=((num_kv_heads + mapping.tp_size - 1) // mapping.tp_size),
+ d=tensorrt_llm_llama.hidden_size,
+ dH=head_size,
+ )
+ v_weight = ckpt[prefix + "wv.weight"].clone()
+
+ qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0)
+ ckpt[prefix + "qkv.weight"] = qkv_weight
+
+ for k, v in ckpt.items():
+ v = torch_to_numpy(v.to(torch_dtype).detach().cpu())
+ if "tok_embeddings" in k:
+ if not tensorrt_llm_llama.use_parallel_embedding:
+ v = gather_embedding(v, k, num_ckpts)
+ elif tensorrt_llm_llama.embedding_sharding_dim == 0:
+ # this needs a gather and then resplit along different dims
+ v = gather_embedding(v, k, num_ckpts)
+ v = split(v, mapping.tp_size, mapping.tp_rank, 0)
+ if mapping.is_first_pp_rank():
+ tensorrt_llm_llama.vocab_embedding.weight.value = v
+ elif "output" in k:
+ if mapping.is_last_pp_rank():
+ tensorrt_llm_llama.lm_head.weight.value = v
+ elif k == "norm.weight":
+ if mapping.is_last_pp_rank():
+ tensorrt_llm_llama.ln_f.weight.value = v
+ else:
+ # layer specific weights
+ layer_idx = extract_layer_idx(k)
+ if layer_idx is None:
+ continue
+ idx = int(layer_idx) - mapping.pp_rank * tensorrt_llm_llama.num_layers
+ if idx >= tensorrt_llm_llama.num_layers:
+ continue
+ if "attention_norm.weight" in k:
+ tensorrt_llm_llama.layers[idx].input_layernorm.weight.value = v
+ elif "ffn_norm.weight" in k:
+ tensorrt_llm_llama.layers[idx].post_layernorm.weight.value = v
+ elif "feed_forward.w3.weight" in k:
+ tensorrt_llm_llama.layers[idx].mlp.gate.weight.value = v
+ elif "feed_forward.w2.weight" in k:
+ tensorrt_llm_llama.layers[idx].mlp.proj.weight.value = v
+ elif "feed_forward.w1.weight" in k:
+ tensorrt_llm_llama.layers[idx].mlp.fc.weight.value = v
+ elif "attention.wo.weight" in k:
+ tensorrt_llm_llama.layers[idx].attention.dense.weight.value = v
+ elif "attention.qkv.weight" in k:
+ tensorrt_llm_llama.layers[idx].attention.qkv.weight.value = v
+
+ tok = time.time()
+ t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
+ tensorrt_llm.logger.info(f"Weights loaded. Total time: {t}")
+ return
+
+
+def load_from_binary(
+ tensorrt_llm_llama: LLaMAForCausalLM,
+ dir_path,
+ mapping=Mapping(),
+ fp16=False,
+ multi_query_mode=False,
+):
+ tensorrt_llm.logger.info("Loading weights from FT...")
+ tik = time.time()
+
+ quant_mode = getattr(tensorrt_llm_llama, "quant_mode", QuantMode(0))
+
+ (
+ n_embd,
+ n_head,
+ n_layer,
+ n_positions,
+ vocab_size,
+ hidden_act,
+ inter_size,
+ n_kv_head,
+ ) = parse_ft_config(Path(dir_path) / "config.ini")
+ np_dtype = np.float16 if fp16 else np.float32
+
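+    # fromfile() reads one raw binary tensor from the FT-style weight directory
+    # and returns None when the file does not exist, so callers can skip tensors
+    # that are only produced for certain quantization modes.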
+ def fromfile(dir_path, name, shape=None, dtype=None):
+ dtype = np_dtype if dtype is None else dtype
+ p = dir_path + "/" + name
+ if Path(p).exists():
+ t = np.fromfile(p, dtype=dtype)
+ if shape is not None:
+ t = t.reshape(shape)
+ return t
+ return None
+
+ def set_smoothquant_scale_factors(
+ module,
+ pre_scale_weight,
+ dir_path,
+ basename,
+ shape,
+ per_tok_dyn,
+ per_channel,
+ is_qkv=False,
+ rank=None,
+ ):
+ suffix = "bin"
+ if per_channel:
+ if rank is not None:
+ suffix = f"{rank}." + suffix
+ suffix = "col." + suffix
+
+ col_shape = shape if (per_channel or is_qkv) else [1, 1]
+
+ if per_tok_dyn:
+ if pre_scale_weight is not None:
+ pre_scale_weight.value = np.array([1.0], dtype=np.float32)
+ if is_qkv and not per_channel:
+ t = fromfile(
+ dir_path,
+ f"{basename}scale_w_quant_orig.{rank}.{suffix}",
+ col_shape,
+ np.float32,
+ )
+ else:
+ t = fromfile(
+ dir_path,
+ f"{basename}scale_w_quant_orig.{suffix}",
+ col_shape,
+ np.float32,
+ )
+ module.per_channel_scale.value = t
+ else:
+ t = fromfile(dir_path, f"{basename}scale_x_orig_quant.bin", [1], np.float32)
+ pre_scale_weight.value = t
+ if is_qkv:
+ t = fromfile(
+ dir_path,
+ f"{basename}scale_y_accum_quant.{rank}.{suffix}",
+ col_shape,
+ np.float32,
+ )
+ else:
+ t = fromfile(
+ dir_path,
+ f"{basename}scale_y_accum_quant.{suffix}",
+ col_shape,
+ np.float32,
+ )
+ module.per_channel_scale.value = t
+ t = fromfile(
+ dir_path, f"{basename}scale_y_quant_orig.bin", [1, 1], np.float32
+ )
+ module.act_scale.value = t
+
+ def set_smoother(module, dir_path, base_name, shape, rank):
+ suffix = f"{rank}.bin"
+ t = fromfile(dir_path, f"{base_name}.smoother.{suffix}", shape, np.float32)
+ module.smoother.value = t
+
+ # Determine the quantization mode.
+ quant_mode = getattr(tensorrt_llm_llama, "quant_mode", QuantMode(0))
+ if quant_mode.is_int8_weight_only():
+ plugin_weight_only_quant_type = torch.int8
+ elif quant_mode.is_int4_weight_only():
+ plugin_weight_only_quant_type = torch.quint4x2
+ # Do we use SmoothQuant?
+ use_smooth_quant = quant_mode.has_act_and_weight_quant()
+ # Do we use quantization per token?
+ quant_per_token_dyn = quant_mode.has_per_token_dynamic_scaling()
+ # Do we use quantization per channel?
+ quant_per_channel = quant_mode.has_per_channel_scaling()
+
+ # Do we use INT4/INT8 weight-only?
+ use_weight_only = quant_mode.is_weight_only()
+
+ # Int8 KV cache
+ use_int8_kv_cache = quant_mode.has_int8_kv_cache()
+
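+    # sq_trick() reinterprets int8 SmoothQuant weight buffers as float32 (a pure
+    # dtype view, no numeric conversion); see the "workaround for trt not
+    # supporting int8 inputs" notes below.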
+ def sq_trick(x):
+ return x.view(np.float32) if use_smooth_quant else x
+
+    # File-name suffix that encodes the TP rank and quantization layout of the serialized weights.
+ suffix = gen_suffix(mapping.tp_rank, use_smooth_quant, quant_per_channel)
+ # The type of weights.
+ w_type = np_dtype if not use_smooth_quant else np.int8
+
+ if mapping.is_first_pp_rank():
+ tensorrt_llm_llama.vocab_embedding.weight.value = fromfile(
+ dir_path, "vocab_embedding.weight.bin", [vocab_size, n_embd]
+ )
+
+ if mapping.is_last_pp_rank():
+ tensorrt_llm_llama.ln_f.weight.value = fromfile(dir_path, "ln_f.weight.bin")
+ # share input embedding
+ lm_head_weight = fromfile(dir_path, "lm_head.weight.bin", [vocab_size, n_embd])
+
+ if vocab_size % mapping.tp_size != 0:
+ # padding
+ vocab_size_padded = tensorrt_llm_llama.lm_head.out_features * mapping.tp_size
+ pad_width = vocab_size_padded - vocab_size
+ lm_head_weight = np.pad(
+ lm_head_weight, ((0, pad_width), (0, 0)), "constant", constant_values=0
+ )
+ if mapping.is_last_pp_rank():
+ tensorrt_llm_llama.lm_head.weight.value = np.ascontiguousarray(
+ split(lm_head_weight, mapping.tp_size, mapping.tp_rank)
+ )
+
+ layers_range = list(
+ range(
+ mapping.pp_rank * tensorrt_llm_llama.num_layers,
+ (mapping.pp_rank + 1) * tensorrt_llm_llama.num_layers,
+ 1,
+ )
+ )
+
+ for i in layers_range:
+ n_groups = n_head // n_kv_head
+ c_attn_out_dim = (
+ (3 * n_embd // mapping.tp_size)
+ if not multi_query_mode
+ else (
+ n_embd // mapping.tp_size
+ + (n_embd // n_head * n_groups) // mapping.tp_size * 2
+ )
+ )
+ idx = i - mapping.pp_rank * tensorrt_llm_llama.num_layers
+ tensorrt_llm_llama.layers[idx].input_layernorm.weight.value = fromfile(
+ dir_path, "model.layers." + str(i) + ".input_layernorm.weight.bin"
+ )
+ t = fromfile(
+ dir_path,
+ "model.layers." + str(i) + ".attention.query_key_value.weight." + suffix,
+ [n_embd, c_attn_out_dim],
+ w_type,
+ )
+ if t is not None:
+ dst = tensorrt_llm_llama.layers[idx].attention.qkv.weight
+ if use_smooth_quant:
+ dst.value = sq_trick(np.ascontiguousarray(np.transpose(t, [1, 0])))
+ set_smoothquant_scale_factors(
+ tensorrt_llm_llama.layers[idx].attention.qkv,
+ tensorrt_llm_llama.layers[idx].input_layernorm.scale_to_int,
+ dir_path,
+ "model.layers." + str(i) + ".attention.query_key_value.",
+ [1, c_attn_out_dim],
+ quant_per_token_dyn,
+ quant_per_channel,
+ rank=mapping.tp_rank,
+ is_qkv=True,
+ )
+ elif use_weight_only:
+ (
+ processed_torch_weights,
+ torch_weight_scales,
+ ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
+ torch.tensor(t), plugin_weight_only_quant_type
+ )
+ # workaround for trt not supporting int8 inputs in plugins currently
+ dst.value = processed_torch_weights.view(dtype=torch.float32).numpy()
+                scales = tensorrt_llm_llama.layers[idx].attention.qkv.per_channel_scale
+ scales.value = torch_weight_scales.numpy()
+ else:
+ dst.value = np.ascontiguousarray(np.transpose(t, [1, 0]))
+
+ dst = tensorrt_llm_llama.layers[idx].attention.dense.weight
+ t = fromfile(
+ dir_path,
+ "model.layers." + str(i) + ".attention.dense.weight." + suffix,
+ [n_embd // mapping.tp_size, n_embd],
+ w_type,
+ )
+ if use_smooth_quant:
+ dst.value = sq_trick(np.ascontiguousarray(np.transpose(t, [1, 0])))
+ dense_scale = getattr(
+ tensorrt_llm_llama.layers[idx].attention,
+ "quantization_scaling_factor",
+ None,
+ )
+ set_smoothquant_scale_factors(
+ tensorrt_llm_llama.layers[idx].attention.dense,
+ dense_scale,
+ dir_path,
+ "model.layers." + str(i) + ".attention.dense.",
+ [1, n_embd],
+ quant_per_token_dyn,
+ quant_per_channel,
+ )
+ set_smoother(
+ tensorrt_llm_llama.layers[idx].attention.dense,
+ dir_path,
+ "model.layers." + str(i) + ".attention.dense",
+ [1, n_embd // mapping.tp_size],
+ mapping.tp_rank,
+ )
+ elif use_weight_only:
+ (
+ processed_torch_weights,
+ torch_weight_scales,
+ ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
+ torch.tensor(t), plugin_weight_only_quant_type
+ )
+ # workaround for trt not supporting int8 inputs in plugins currently
+ dst.value = processed_torch_weights.view(dtype=torch.float32).numpy()
+            scales = tensorrt_llm_llama.layers[idx].attention.dense.per_channel_scale
+ scales.value = torch_weight_scales.numpy()
+ else:
+ dst.value = np.ascontiguousarray(np.transpose(t, [1, 0]))
+
+ dst = tensorrt_llm_llama.layers[idx].post_layernorm.weight
+ dst.value = fromfile(
+ dir_path, "model.layers." + str(i) + ".post_layernorm.weight.bin"
+ )
+
+ t = fromfile(
+ dir_path,
+ "model.layers." + str(i) + ".mlp.fc.weight." + suffix,
+ [n_embd, inter_size // mapping.tp_size],
+ w_type,
+ )
+
+ if use_smooth_quant:
+ tensorrt_llm_llama.layers[idx].mlp.fc.weight.value = sq_trick(
+ np.ascontiguousarray(np.transpose(t, [1, 0]))
+ )
+ set_smoothquant_scale_factors(
+ tensorrt_llm_llama.layers[idx].mlp.fc,
+ tensorrt_llm_llama.layers[idx].post_layernorm.scale_to_int,
+ dir_path,
+ "model.layers." + str(i) + ".mlp.fc.",
+ [1, inter_size // mapping.tp_size],
+ quant_per_token_dyn,
+ quant_per_channel,
+ rank=mapping.tp_rank,
+ )
+ elif use_weight_only:
+            dst = tensorrt_llm_llama.layers[idx].mlp.fc.weight
+ (
+ processed_torch_weights,
+ torch_weight_scales,
+ ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
+ torch.tensor(t), plugin_weight_only_quant_type
+ )
+ # workaround for trt not supporting int8 inputs in plugins currently
+ dst.value = processed_torch_weights.view(dtype=torch.float32).numpy()
+            scales = tensorrt_llm_llama.layers[idx].mlp.fc.per_channel_scale
+ scales.value = torch_weight_scales.numpy()
+ else:
+ tensorrt_llm_llama.layers[idx].mlp.fc.weight.value = np.ascontiguousarray(
+ np.transpose(t, [1, 0])
+ )
+
+ t = fromfile(
+ dir_path,
+ "model.layers." + str(i) + ".mlp.gate.weight." + suffix,
+ [n_embd, inter_size // mapping.tp_size],
+ w_type,
+ )
+ if use_smooth_quant:
+ tensorrt_llm_llama.layers[idx].mlp.gate.weight.value = sq_trick(
+ np.ascontiguousarray(np.transpose(t, [1, 0]))
+ )
+ set_smoothquant_scale_factors(
+ tensorrt_llm_llama.layers[idx].mlp.gate,
+ tensorrt_llm_llama.layers[idx].post_layernorm.scale_to_int,
+ dir_path,
+ "model.layers." + str(i) + ".mlp.gate.",
+ [1, inter_size // mapping.tp_size],
+ quant_per_token_dyn,
+ quant_per_channel,
+ rank=mapping.tp_rank,
+ )
+ elif use_weight_only:
+            dst = tensorrt_llm_llama.layers[idx].mlp.gate.weight
+ (
+ processed_torch_weights,
+ torch_weight_scales,
+ ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
+ torch.tensor(t), plugin_weight_only_quant_type
+ )
+ # workaround for trt not supporting int8 inputs in plugins currently
+ dst.value = processed_torch_weights.view(dtype=torch.float32).numpy()
+            scales = tensorrt_llm_llama.layers[idx].mlp.gate.per_channel_scale
+ scales.value = torch_weight_scales.numpy()
+ else:
+ tensorrt_llm_llama.layers[idx].mlp.gate.weight.value = np.ascontiguousarray(
+ np.transpose(t, [1, 0])
+ )
+
+ t = fromfile(
+ dir_path,
+ "model.layers." + str(i) + ".mlp.proj.weight." + suffix,
+ [inter_size // mapping.tp_size, n_embd],
+ w_type,
+ )
+ if use_smooth_quant:
+ tensorrt_llm_llama.layers[idx].mlp.proj.weight.value = sq_trick(
+ np.ascontiguousarray(np.transpose(t, [1, 0]))
+ )
+ proj_scale = getattr(
+ tensorrt_llm_llama.layers[idx].mlp, "quantization_scaling_factor", None
+ )
+ set_smoothquant_scale_factors(
+ tensorrt_llm_llama.layers[idx].mlp.proj,
+ proj_scale,
+ dir_path,
+ "model.layers." + str(i) + ".mlp.proj.",
+ [1, n_embd],
+ quant_per_token_dyn,
+ quant_per_channel,
+ )
+ set_smoother(
+ tensorrt_llm_llama.layers[idx].mlp.proj,
+ dir_path,
+ "model.layers." + str(i) + ".mlp.proj",
+ [1, inter_size // mapping.tp_size],
+ mapping.tp_rank,
+ )
+ elif use_weight_only:
+            dst = tensorrt_llm_llama.layers[idx].mlp.proj.weight
+ (
+ processed_torch_weights,
+ torch_weight_scales,
+ ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
+ torch.tensor(t), plugin_weight_only_quant_type
+ )
+ # workaround for trt not supporting int8 inputs in plugins currently
+ dst.value = processed_torch_weights.view(dtype=torch.float32).numpy()
+            scales = tensorrt_llm_llama.layers[idx].mlp.proj.per_channel_scale
+ scales.value = torch_weight_scales.numpy()
+ else:
+ tensorrt_llm_llama.layers[idx].mlp.proj.weight.value = np.ascontiguousarray(
+ np.transpose(t, [1, 0])
+ )
+
+ if use_int8_kv_cache:
+ t = fromfile(
+ dir_path,
+ "model.layers."
+ + str(i)
+ + ".attention.query_key_value.scale_y_quant_orig.bin",
+ [1],
+ np.float32,
+ )
+ tensorrt_llm_llama.layers[idx].attention.kv_orig_quant_scale.value = 1.0 / t
+ tensorrt_llm_llama.layers[idx].attention.kv_quant_orig_scale.value = t
+
+ tok = time.time()
+ t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
+ tensorrt_llm.logger.info(f"Weights loaded. Total time: {t}")
+
+
+def load_from_gptq_llama(
+ tensorrt_llm_llama, quant_ckpt_path, mapping=Mapping(), dtype="float16"
+):
+ tensorrt_llm.logger.info("Loading weights from groupwise GPTQ LLaMA safetensors...")
+ tik = time.time()
+
+ if quant_ckpt_path.endswith(".safetensors"):
+ groupwise_qweight_safetensors = safe_open(
+ quant_ckpt_path, framework="pt", device=0
+ )
+ model_params = {
+ key: groupwise_qweight_safetensors.get_tensor(key)
+ for key in groupwise_qweight_safetensors.keys()
+ }
+ elif quant_ckpt_path.endswith(".pt"):
+ model_params = torch.load(quant_ckpt_path, map_location=torch.device("cpu"))
+ else:
+ assert False, "Quantized checkpoint format not supported!"
+
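+    # Each int32 element packs eight 4-bit values; viewed as uint8, every byte
+    # holds two of them: the low nibble (byte % 16) and the high nibble
+    # (byte // 16). For example, the byte 0xB3 unpacks to [3, 11].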
+ def unpack_int32_into_int8(w_packed):
+        # Unpack int32-packed 4-bit values so that each int8 element holds one value.
+ w_packed_int4x2 = w_packed.contiguous().view(torch.uint8)
+ w_unpacked = torch.zeros(
+ w_packed_int4x2.shape[0], w_packed_int4x2.shape[1] * 2, dtype=torch.int8
+ )
+ w_unpacked[:, ::2] = w_packed_int4x2 % 16
+ w_unpacked[:, 1::2] = w_packed_int4x2 // 16
+ return w_unpacked.contiguous()
+
+ def preprocess_groupwise_weight_params(
+ weight_name, qweight_int32=None, qzeros_int32=None, scales_fp16=None
+ ):
+ if weight_name is not None:
+ qweight_int32 = model_params[weight_name].cpu()
+ qzeros_int32 = model_params[weight_name[:-7] + "qzeros"].cpu()
+ scales_fp16 = model_params[weight_name[:-7] + "scales"].cpu()
+
+ UINT4_TO_INT4_FLAG = 1
+ GPTQ_FLAG = 1
+ packer = torch.ops.fastertransformer.pack_int8_tensor_to_packed_int4
+ preprocessor = torch.ops.fastertransformer.preprocess_weights_for_mixed_gemm
+
+ qweight_unpacked_int8 = (
+ unpack_int32_into_int8(qweight_int32.T).T.contiguous() - 8
+ )
+ qweight_interleaved = preprocessor(
+ packer(qweight_unpacked_int8), torch.quint4x2
+ ).view(torch.float32)
+ # zeros = zeros * scales
+ qzeros_unpacked_int32 = unpack_int32_into_int8(qzeros_int32)
+ zeros_x_scales_fp16 = (
+ -qzeros_unpacked_int32 + 8 * UINT4_TO_INT4_FLAG - GPTQ_FLAG
+ ) * scales_fp16
+ zeros_x_scales_fp16 = zeros_x_scales_fp16.half()
+
+ # return processed interleaved weight, original scales and zeros * scales
+ return (
+ qweight_interleaved.contiguous(),
+ scales_fp16.contiguous(),
+ zeros_x_scales_fp16.contiguous(),
+ )
+
+    layer_ids = [extract_layer_idx(key) for key in model_params.keys()]
+ layer_ids = [int(layer_idx) for layer_idx in layer_ids if layer_idx is not None]
+ num_hidden_layers = max(layer_ids) + 1
+ num_kv_heads = tensorrt_llm_llama.num_kv_heads
+ mha_mode = num_kv_heads == tensorrt_llm_llama.num_heads
+ suffixs = ["qweight", "qzeros", "scales"]
+
+ layers_per_pipeline_stage = num_hidden_layers // mapping.pp_size
+ layers_range = list(
+ range(
+ mapping.pp_rank * layers_per_pipeline_stage,
+ (mapping.pp_rank + 1) * layers_per_pipeline_stage,
+ 1,
+ )
+ )
+
+ for l in layers_range:
+ prefix = f"model.layers.{l}.self_attn."
+ split_qkv_suf = []
+
+ for suf in suffixs:
+ q_part = model_params[prefix + "q_proj." + suf].cpu()
+ k_part = model_params[prefix + "k_proj." + suf].cpu()
+ v_part = model_params[prefix + "v_proj." + suf].cpu()
+ qkv_part = torch.cat([q_part, k_part, v_part], dim=0)
+ dim = qkv_part.shape
+ qkv_part = qkv_part.reshape(3, dim[0] // 3, dim[1])
+ split_qkv = qkv_part.split(dim[1] // mapping.tp_size, dim=2)[
+ mapping.tp_rank
+ ]
+ split_qkv = torch.cat(
+ [
+ split_qkv[0, :, :].squeeze(0),
+ split_qkv[1, :, :].squeeze(0),
+ split_qkv[2, :, :].squeeze(0),
+ ],
+ dim=1,
+ )
+ split_qkv_suf.append(split_qkv)
+
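+        # preprocess_groupwise_weight_params() returns (packed qweight, scales,
+        # zeros * scales). The names th_zero / th_scale below are swapped
+        # relative to what they hold: th_zero receives the scales and th_scale
+        # receives zeros * scales, and the assignments account for that.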
+ th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params(
+ None, split_qkv_suf[0], split_qkv_suf[1], split_qkv_suf[2]
+ )
+
+ idx = l - mapping.pp_rank * layers_per_pipeline_stage
+ tensorrt_llm_llama.layers[idx].attention.qkv.qweight.value = th_qweight.numpy()
+ tensorrt_llm_llama.layers[idx].attention.qkv.scale.value = th_zero.numpy()
+ tensorrt_llm_llama.layers[idx].attention.qkv.zero.value = th_scale.numpy()
+
+ torch_dtype = str_dtype_to_torch(dtype)
+
+ for k, v in model_params.items():
+ if isinstance(v, list):
+ v = [torch_to_numpy(vv.to(torch_dtype).detach().cpu()) for vv in v]
+ else:
+ v = torch_to_numpy(v.to(torch_dtype).detach().cpu())
+ if "model.embed_tokens.weight" in k:
+ if mapping.is_first_pp_rank():
+ tensorrt_llm_llama.vocab_embedding.weight.value = v
+ elif "model.norm.weight" in k:
+ if mapping.is_last_pp_rank():
+ tensorrt_llm_llama.ln_f.weight.value = v
+ elif "lm_head.weight" in k:
+ if mapping.is_last_pp_rank():
+ tensorrt_llm_llama.lm_head.weight.value = np.ascontiguousarray(
+ split(v, mapping.tp_size, mapping.tp_rank)
+ )
+ else:
+ layer_idx = extract_layer_idx(k)
+ if layer_idx is None:
+ continue
+ idx = int(layer_idx)
+ if idx not in layers_range:
+ continue
+ idx = idx - mapping.pp_rank * layers_per_pipeline_stage
+
+ if "input_layernorm.weight" in k:
+ tensorrt_llm_llama.layers[idx].input_layernorm.weight.value = v
+ elif "post_attention_layernorm.weight" in k:
+ tensorrt_llm_llama.layers[idx].post_layernorm.weight.value = v
+ elif "self_attn.o_proj.qweight" in k:
+ split_v_suf = []
+ for suf in suffixs:
+ v = model_params[k[:-7] + suf].cpu()
+ split_v = v.split(v.shape[0] // mapping.tp_size, dim=0)[
+ mapping.tp_rank
+ ]
+ split_v_suf.append(split_v)
+ th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params(
+ None, split_v_suf[0], split_v_suf[1], split_v_suf[2]
+ )
+ tensorrt_llm_llama.layers[
+ idx
+ ].attention.dense.qweight.value = th_qweight.numpy()
+ tensorrt_llm_llama.layers[
+ idx
+ ].attention.dense.scale.value = th_zero.numpy()
+ tensorrt_llm_llama.layers[
+ idx
+ ].attention.dense.zero.value = th_scale.numpy()
+ elif "mlp.up_proj.qweight" in k:
+ split_v_suf = []
+ for suf in suffixs:
+ v = model_params[k[:-7] + suf].cpu()
+ split_v = v.split(v.shape[1] // mapping.tp_size, dim=1)[
+ mapping.tp_rank
+ ]
+ split_v_suf.append(split_v)
+ th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params(
+ None, split_v_suf[0], split_v_suf[1], split_v_suf[2]
+ )
+ tensorrt_llm_llama.layers[
+ idx
+ ].mlp.gate.qweight.value = th_qweight.numpy()
+ tensorrt_llm_llama.layers[idx].mlp.gate.scale.value = th_zero.numpy()
+ tensorrt_llm_llama.layers[idx].mlp.gate.zero.value = th_scale.numpy()
+ elif "mlp.down_proj.qweight" in k:
+ split_v_suf = []
+ for suf in suffixs:
+ v = model_params[k[:-7] + suf].cpu()
+ split_v = v.split(v.shape[0] // mapping.tp_size, dim=0)[
+ mapping.tp_rank
+ ]
+ split_v_suf.append(split_v)
+ th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params(
+ None, split_v_suf[0], split_v_suf[1], split_v_suf[2]
+ )
+ tensorrt_llm_llama.layers[
+ idx
+ ].mlp.proj.qweight.value = th_qweight.numpy()
+ tensorrt_llm_llama.layers[idx].mlp.proj.scale.value = th_zero.numpy()
+ tensorrt_llm_llama.layers[idx].mlp.proj.zero.value = th_scale.numpy()
+ elif "mlp.gate_proj.qweight" in k:
+ split_v_suf = []
+ for suf in suffixs:
+ v = model_params[k[:-7] + suf].cpu()
+ split_v = v.split(v.shape[1] // mapping.tp_size, dim=1)[
+ mapping.tp_rank
+ ]
+ split_v_suf.append(split_v)
+ th_qweight, th_zero, th_scale = preprocess_groupwise_weight_params(
+ None, split_v_suf[0], split_v_suf[1], split_v_suf[2]
+ )
+ tensorrt_llm_llama.layers[idx].mlp.fc.qweight.value = th_qweight.numpy()
+ tensorrt_llm_llama.layers[idx].mlp.fc.scale.value = th_zero.numpy()
+ tensorrt_llm_llama.layers[idx].mlp.fc.zero.value = th_scale.numpy()
+
+ tok = time.time()
+ t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
+ tensorrt_llm.logger.info(f"Weights loaded. Total time: {t}")
+ return
+
+
+def load_from_awq_llama(
+ tensorrt_llm_llama: LLaMAForCausalLM,
+ quant_ckpt_path,
+ mapping=Mapping(),
+ dtype="float16",
+):
+ tensorrt_llm.logger.info("Loading weights from groupwise AWQ LLaMA safetensors...")
+ tik = time.time()
+
+ if quant_ckpt_path.endswith(".safetensors"):
+ groupwise_qweight_safetensors = safe_open(
+ quant_ckpt_path, framework="pt", device=0
+ )
+ awq_llama = {
+ key: groupwise_qweight_safetensors.get_tensor(key)
+ for key in groupwise_qweight_safetensors.keys()
+ }
+ elif quant_ckpt_path.endswith(".pt"):
+ awq_llama = torch.load(quant_ckpt_path, map_location=torch.device("cpu"))
+ else:
+ assert False, "Quantized checkpoint format not supported!"
+
+ group_size = (
+ awq_llama["model.layers.0.self_attn.o_proj.weight"].numel()
+ // awq_llama["model.layers.0.self_attn.o_proj.weight_quantizer._amax"].numel()
+ )
+
+ awq_llama_block_names = [
+ "input_layernorm.weight",
+ "post_attention_layernorm.weight",
+ ]
+
+ tensorrt_llm_llama_block_names = [
+ "input_layernorm.weight",
+ "post_layernorm.weight",
+ ]
+
+    # quant_mode is not used here; the AWQ weights are repacked below regardless.
+
+ packer = torch.ops.fastertransformer.pack_int8_tensor_to_packed_int4
+ preprocessor = torch.ops.fastertransformer.preprocess_weights_for_mixed_gemm
+ torch_dtype = str_dtype_to_torch(dtype)
+
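+    # AWQ_quantize_pack_preprocess() divides the weight by its group-wise scale
+    # (expanded to one scale per input row), rounds and clamps to the signed
+    # int4 range [-8, 7], packs two int4 values per byte and runs the mixed-GEMM
+    # preprocessor, returning the buffer reinterpreted as float32.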
+ def AWQ_quantize_pack_preprocess(weight, scale):
+ scale = scale.repeat_interleave(group_size, dim=0)
+ weight = weight / scale
+ qweight_int8 = torch.clamp(torch.round(weight.cuda()).char(), -8, 7)
+ int4_weight = packer(qweight_int8.cpu())
+ int4_weight = preprocessor(int4_weight, torch.quint4x2)
+ return int4_weight.view(torch.float32).cpu().numpy()
+
+ def process_and_assign_weight(awq_llama, mPrefix, mOp, tp_dim=0):
+ weight = awq_llama[mPrefix + ".weight"].T.contiguous()
+ [k, n] = weight.shape
+ weight = weight.split(weight.shape[tp_dim] // mapping.tp_size, dim=tp_dim)[
+ mapping.tp_rank
+ ]
+ amax = (
+ awq_llama[mPrefix + ".weight_quantizer._amax"]
+ .reshape((n, int(k / group_size)))
+ .T.contiguous()
+ )
+ amax = amax.split(amax.shape[tp_dim] // mapping.tp_size, dim=tp_dim)[
+ mapping.tp_rank
+ ]
+ pre_quant_scale = awq_llama[
+ mPrefix + ".input_quantizer._pre_quant_scale"
+ ].reshape((1, k))
+ if tp_dim == 0:
+ pre_quant_scale = pre_quant_scale.split(k // mapping.tp_size, dim=1)[
+ mapping.tp_rank
+ ]
+ scale = amax / 8.0
+ mOp.qweight.value = AWQ_quantize_pack_preprocess(weight, scale)
+ mOp.scale.value = scale.to(torch_dtype).cpu().numpy()
+ mOp.pre_quant_scale.value = pre_quant_scale.to(torch_dtype).cpu().numpy()
+
+ def deSmooth(weight, pre_quant_scale):
+ [k, n] = weight.shape
+ pre_quant_scale = pre_quant_scale.repeat((n, 1)).transpose(1, 0).contiguous()
+ weight = weight * pre_quant_scale
+ return weight
+
+ def reSmooth(weight, pre_quant_scale):
+ [k, n] = weight.shape
+ pre_quant_scale = pre_quant_scale.repeat((n, 1)).transpose(1, 0).contiguous()
+ weight = weight / pre_quant_scale
+ return weight
+
+ def get_scale(weight):
+ weight = weight.T.contiguous()
+ [n, k] = weight.shape
+ weight = weight.reshape(n, int(k / group_size), group_size)
+ weight = torch.abs(weight.reshape(-1, group_size))
+ amax, idx = weight.max(1)
+ amax = amax.reshape(n, int(k / group_size)).T.contiguous()
+ return amax / 8
+
+ def reSmooth_and_get_scale(weight, pre_quant_scale, avg_pre_quant_scale):
+ weight = deSmooth(weight, pre_quant_scale)
+ weight = reSmooth(weight, avg_pre_quant_scale)
+ scale = get_scale(weight)
+ return weight, scale
+
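+    # Because the fused QKV projection shares a single activation pre-quant
+    # scale, process_and_assign_qkv_weight() below removes each projection's own
+    # pre-quant scale, re-applies the averaged QKV scale, and recomputes the
+    # group-wise weight scales on the re-smoothed weights.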
+ def process_and_assign_qkv_weight(awq_llama, prefix, mOp):
+ q_weight = awq_llama[prefix + "self_attn.q_proj.weight"].T.contiguous()
+ k_weight = awq_llama[prefix + "self_attn.k_proj.weight"].T.contiguous()
+ v_weight = awq_llama[prefix + "self_attn.v_proj.weight"].T.contiguous()
+ k = q_weight.shape[0]
+
+ q_weight = q_weight.split(q_weight.shape[1] // mapping.tp_size, dim=1)[
+ mapping.tp_rank
+ ]
+ k_weight = k_weight.split(k_weight.shape[1] // mapping.tp_size, dim=1)[
+ mapping.tp_rank
+ ]
+ v_weight = v_weight.split(v_weight.shape[1] // mapping.tp_size, dim=1)[
+ mapping.tp_rank
+ ]
+
+ q_pre_quant_scale = awq_llama[
+ prefix + "self_attn.q_proj.input_quantizer._pre_quant_scale"
+ ].reshape((1, k))
+ k_pre_quant_scale = awq_llama[
+ prefix + "self_attn.k_proj.input_quantizer._pre_quant_scale"
+ ].reshape((1, k))
+ v_pre_quant_scale = awq_llama[
+ prefix + "self_attn.v_proj.input_quantizer._pre_quant_scale"
+ ].reshape((1, k))
+
+ qkv_pre_quant_scale = (
+ q_pre_quant_scale + k_pre_quant_scale + v_pre_quant_scale
+ ) / 3.0
+ q_weight, q_scale = reSmooth_and_get_scale(
+ q_weight, q_pre_quant_scale, qkv_pre_quant_scale
+ )
+ k_weight, k_scale = reSmooth_and_get_scale(
+ k_weight, k_pre_quant_scale, qkv_pre_quant_scale
+ )
+ v_weight, v_scale = reSmooth_and_get_scale(
+ v_weight, v_pre_quant_scale, qkv_pre_quant_scale
+ )
+
+ qkv_weights = torch.cat((q_weight, k_weight, v_weight), dim=1)
+ qkv_scale = torch.cat((q_scale, k_scale, v_scale), dim=1)
+
+ mOp.pre_quant_scale.value = qkv_pre_quant_scale.to(torch_dtype).cpu().numpy()
+ mOp.qweight.value = AWQ_quantize_pack_preprocess(qkv_weights, qkv_scale)
+ mOp.scale.value = qkv_scale.to(torch_dtype).cpu().numpy()
+
+ # Check if we need to pad vocab
+ v = awq_llama.get("model.embed_tokens.weight")
+ [vocab_size, k] = v.shape
+ pad_vocab = False
+ pad_vocab_size = vocab_size
+ if vocab_size % 64 != 0:
+ pad_vocab = True
+ pad_vocab_size = int((vocab_size + 63) / 64) * 64
+ if pad_vocab:
+ new_v = torch.zeros([pad_vocab_size, k])
+ new_v[:vocab_size, :] = v
+ v = new_v
+ if mapping.is_first_pp_rank():
+ tensorrt_llm_llama.vocab_embedding.weight.value = (
+ v.to(torch_dtype).cpu().numpy()
+ )
+
+ layer_ids = [extract_layer_idx(key) for key in awq_llama.keys()]
+ layer_ids = [int(layer_idx) for layer_idx in layer_ids if layer_idx is not None]
+
+ num_hidden_layers = max(layer_ids) + 1
+ layers_per_pipeline_stage = num_hidden_layers // mapping.pp_size
+ layers_range = list(
+ range(
+ mapping.pp_rank * layers_per_pipeline_stage,
+ (mapping.pp_rank + 1) * layers_per_pipeline_stage,
+ 1,
+ )
+ )
+
+ for layer_idx in layers_range:
+ prefix = "model.layers." + str(layer_idx) + "."
+ tensorrt_llm.logger.info(f"Process weights in layer: {layer_idx}")
+ for idx, awq_attr in enumerate(awq_llama_block_names):
+ v = awq_llama[prefix + awq_attr]
+ layer = attrgetter(tensorrt_llm_llama_block_names[idx])(
+ tensorrt_llm_llama.layers[layer_idx]
+ )
+ setattr(layer, "value", v.to(torch_dtype).cpu().numpy())
+
+ # Attention QKV Linear
+        # Concatenate the Q, K and V projection weights into one fused QKV tensor.
+ process_and_assign_qkv_weight(
+ awq_llama, prefix, tensorrt_llm_llama.layers[layer_idx].attention.qkv
+ )
+
+ # Attention Dense (out_proj) Linear
+ mPrefix = prefix + "self_attn.o_proj"
+ mOp = tensorrt_llm_llama.layers[layer_idx].attention.dense
+ process_and_assign_weight(awq_llama, mPrefix, mOp, 0)
+
+ # MLP up_proj (mlp.gate) Linear
+ mPrefix = prefix + "mlp.up_proj"
+ mOp = tensorrt_llm_llama.layers[layer_idx].mlp.gate
+ process_and_assign_weight(awq_llama, mPrefix, mOp, 1)
+
+ # MLP down_proj (mlp.proj) Linear
+ mPrefix = prefix + "mlp.down_proj"
+ mOp = tensorrt_llm_llama.layers[layer_idx].mlp.proj
+ process_and_assign_weight(awq_llama, mPrefix, mOp, 0)
+
+ # MLP gate_proj (mlp.fc) Linear
+ mPrefix = prefix + "mlp.gate_proj"
+ mOp = tensorrt_llm_llama.layers[layer_idx].mlp.fc
+ process_and_assign_weight(awq_llama, mPrefix, mOp, 1)
+
+ v = awq_llama["model.norm.weight"]
+ if mapping.is_last_pp_rank():
+ tensorrt_llm_llama.ln_f.weight.value = v.to(torch_dtype).cpu().numpy()
+
+ # lm_head
+ if pad_vocab:
+ weight = awq_llama["lm_head.weight"]
+ [vocab_size, k] = weight.shape
+ new_weight = torch.zeros([pad_vocab_size, k])
+ new_weight[:vocab_size, :] = weight
+ new_weight = new_weight.T.contiguous()
+ amax = awq_llama["lm_head.weight_quantizer._amax"].reshape(
+ [vocab_size, k // group_size]
+ )
+ new_amax = torch.ones([pad_vocab_size, k // group_size])
+ new_amax[:vocab_size, :] = amax
+ new_amax = new_amax.T.contiguous()
+ new_scale = new_amax / 8
+ tensorrt_llm_llama.lm_head.qweight.value = AWQ_quantize_pack_preprocess(
+ new_weight, new_scale
+ )
+ tensorrt_llm_llama.lm_head.scale.value = new_scale.to(torch_dtype).cpu().numpy()
+ tensorrt_llm_llama.lm_head.pre_quant_scale.value = (
+ awq_llama["lm_head.input_quantizer._pre_quant_scale"]
+ .to(torch_dtype)
+ .cpu()
+ .numpy()
+ )
+ else:
+ mPrefix = "lm_head"
+ mOp = tensorrt_llm_llama.lm_head
+ if mapping.is_last_pp_rank():
+ process_and_assign_weight(awq_llama, mPrefix, mOp, 1)
+
+ tok = time.time()
+ t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
+ tensorrt_llm.logger.info(f"Weights loaded. Total time: {t}")
diff --git a/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/gptnext b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/gptnext
new file mode 120000
index 00000000..056bf100
--- /dev/null
+++ b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/gptnext
@@ -0,0 +1 @@
+llama
\ No newline at end of file
diff --git a/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/ensemble/1/.gitkeep b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/ensemble/1/.gitkeep
new file mode 100644
index 00000000..e69de29b
diff --git a/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/ensemble/config.pbtxt b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/ensemble/config.pbtxt
new file mode 100755
index 00000000..cbd087ce
--- /dev/null
+++ b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/ensemble/config.pbtxt
@@ -0,0 +1,228 @@
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+# * Neither the name of NVIDIA CORPORATION nor the names of its
+# contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+name: "ensemble"
+platform: "ensemble"
+max_batch_size: 128
+input [
+ {
+ name: "text_input"
+ data_type: TYPE_STRING
+ dims: [ -1 ]
+ },
+ {
+ name: "max_tokens"
+ data_type: TYPE_UINT32
+ dims: [ -1 ]
+ },
+ {
+ name: "end_id"
+ data_type: TYPE_UINT32
+ dims: [ 1 ]
+ optional: true
+ },
+ {
+ name: "pad_id"
+ data_type: TYPE_UINT32
+ dims: [ 1 ]
+ optional: true
+ },
+ {
+ name: "top_k"
+ data_type: TYPE_UINT32
+ dims: [ 1 ]
+ optional: true
+ },
+ {
+ name: "top_p"
+ data_type: TYPE_FP32
+ dims: [ 1 ]
+ optional: true
+ },
+ {
+ name: "temperature"
+ data_type: TYPE_FP32
+ dims: [ 1 ]
+ optional: true
+ },
+ {
+ name: "length_penalty"
+ data_type: TYPE_FP32
+ dims: [ 1 ]
+ optional: true
+ },
+ {
+ name: "repetition_penalty"
+ data_type: TYPE_FP32
+ dims: [ 1 ]
+ optional: true
+ },
+ {
+ name: "min_length"
+ data_type: TYPE_UINT32
+ dims: [ 1 ]
+ optional: true
+ },
+ {
+ name: "presence_penalty"
+ data_type: TYPE_FP32
+ dims: [ 1 ]
+ optional: true
+ },
+ {
+ name: "random_seed"
+ data_type: TYPE_UINT64
+ dims: [ 1 ]
+ optional: true
+ },
+ {
+ name: "beam_width"
+ data_type: TYPE_UINT32
+ dims: [ 1 ]
+ optional: true
+ },
+ {
+ name: "stream"
+ data_type: TYPE_BOOL
+ dims: [ 1 ]
+ optional: true
+ }
+]
+output [
+ {
+ name: "text_output"
+ data_type: TYPE_STRING
+ dims: [ -1, -1 ]
+ }
+]
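+# The ensemble chains three steps: "preprocessing" tokenizes the request text,
+# "tensorrt_llm" generates output token ids, and "postprocessing" decodes them
+# back into text. Tensor names prefixed with "_" are internal to the ensemble
+# and are never exposed to clients.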
+ensemble_scheduling {
+ step [
+ {
+ model_name: "preprocessing"
+ model_version: -1
+ input_map {
+ key: "QUERY"
+ value: "text_input"
+ }
+ input_map {
+ key: "REQUEST_OUTPUT_LEN"
+ value: "max_tokens"
+ }
+ output_map {
+ key: "REQUEST_INPUT_LEN"
+ value: "_REQUEST_INPUT_LEN"
+ }
+ output_map {
+ key: "INPUT_ID"
+ value: "_INPUT_ID"
+ }
+ output_map {
+ key: "REQUEST_OUTPUT_LEN"
+ value: "_REQUEST_OUTPUT_LEN"
+ }
+ },
+ {
+ model_name: "tensorrt_llm"
+ model_version: -1
+ input_map {
+ key: "input_ids"
+ value: "_INPUT_ID"
+ }
+ input_map {
+ key: "input_lengths"
+ value: "_REQUEST_INPUT_LEN"
+ }
+ input_map {
+ key: "request_output_len"
+ value: "_REQUEST_OUTPUT_LEN"
+ }
+ input_map {
+ key: "end_id"
+ value: "end_id"
+ }
+ input_map {
+ key: "pad_id"
+ value: "pad_id"
+ }
+ input_map {
+ key: "runtime_top_k"
+ value: "top_k"
+ }
+ input_map {
+ key: "runtime_top_p"
+ value: "top_p"
+ }
+ input_map {
+ key: "temperature"
+ value: "temperature"
+ }
+ input_map {
+ key: "len_penalty"
+ value: "length_penalty"
+ }
+ input_map {
+ key: "repetition_penalty"
+ value: "repetition_penalty"
+ }
+ input_map {
+ key: "min_length"
+ value: "min_length"
+ }
+ input_map {
+ key: "presence_penalty"
+ value: "presence_penalty"
+ }
+ input_map {
+ key: "random_seed"
+ value: "random_seed"
+ }
+ input_map {
+ key: "beam_width"
+ value: "beam_width"
+ }
+ input_map {
+ key: "streaming"
+ value: "stream"
+ }
+ output_map {
+ key: "output_ids"
+ value: "_TOKENS_BATCH"
+ }
+ },
+ {
+ model_name: "postprocessing"
+ model_version: -1
+ input_map {
+ key: "TOKENS_BATCH"
+ value: "_TOKENS_BATCH"
+ }
+ output_map {
+ key: "OUTPUT"
+ value: "text_output"
+ }
+ }
+ ]
+}
diff --git a/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/postprocessing/1/model.py b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/postprocessing/1/model.py
new file mode 100755
index 00000000..0e563c96
--- /dev/null
+++ b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/postprocessing/1/model.py
@@ -0,0 +1,173 @@
+# -*- coding: utf-8 -*-
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+
+import numpy as np
+import triton_python_backend_utils as pb_utils
+from transformers import LlamaTokenizer
+
+TOKENIZER_DIR = os.environ.get("TOKENIZER_DIR", "/model")
+
+SPACE_CHAR = 9601  # Code point of the SentencePiece word-boundary marker "▁" (U+2581).
+NEWLINE_CHAR = 60  # ord("<"); byte-fallback tokens such as "<0x0A>" begin with "<".
+STOP_TOKEN = 2  # End-of-sequence token id for the LLaMA tokenizer.
+
+
+class TritonPythonModel:
+ """Your Python model must use the same class name. Every Python model
+ that is created must have "TritonPythonModel" as the class name.
+ """
+
+ def initialize(self, args):
+ """`initialize` is called only once when the model is being loaded.
+ Implementing `initialize` function is optional. This function allows
+ the model to initialize any state associated with this model.
+ Parameters
+ ----------
+ args : dict
+ Both keys and values are strings. The dictionary keys and values are:
+ * model_config: A JSON string containing the model configuration
+ * model_instance_kind: A string containing model instance kind
+ * model_instance_device_id: A string containing model instance device ID
+ * model_repository: Model repository path
+ * model_version: Model version
+ * model_name: Model name
+ """
+ # Parse model configs
+ self.model_config = model_config = json.loads(args["model_config"])
+
+ # Parse model output configs
+ output_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT")
+
+ # Convert Triton types to numpy types
+ self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
+
+ self.tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_DIR, legacy=False)
+ vocab = self.tokenizer.convert_ids_to_tokens(
+ list(range(self.tokenizer.vocab_size))
+ )
+
+ def execute(self, requests):
+ """`execute` must be implemented in every Python model. `execute`
+ function receives a list of pb_utils.InferenceRequest as the only
+ argument. This function is called when an inference is requested
+ for this model. Depending on the batching configuration (e.g. Dynamic
+ Batching) used, `requests` may contain multiple requests. Every
+        Python model must create one pb_utils.InferenceResponse for every
+ pb_utils.InferenceRequest in `requests`. If there is an error, you can
+ set the error argument when creating a pb_utils.InferenceResponse.
+ Parameters
+ ----------
+ requests : list
+ A list of pb_utils.InferenceRequest
+ Returns
+ -------
+ list
+ A list of pb_utils.InferenceResponse. The length of this list must
+ be the same as `requests`
+ """
+
+ responses = []
+
+        # Every Python backend must iterate over every one of the requests
+ # and create a pb_utils.InferenceResponse for each of them.
+ for request in requests:
+ # Get input tensors
+ tokens_batch = pb_utils.get_input_tensor_by_name(
+ request, "TOKENS_BATCH"
+ ).as_numpy()
+
+ # Reshape Input
+ # tokens_batch = tokens_batch.reshape([-1, tokens_batch.shape[0]])
+ # tokens_batch = tokens_batch.T
+
+ # Postprocessing output data.
+ outputs = self._postprocessing(tokens_batch)
+
+ # Create output tensors. You need pb_utils.Tensor
+ # objects to create pb_utils.InferenceResponse.
+ output_tensor = pb_utils.Tensor(
+ "OUTPUT", np.array(outputs).astype(self.output_dtype)
+ )
+
+ # Create InferenceResponse. You can set an error here in case
+ # there was a problem with handling this inference request.
+ # Below is an example of how you can set errors in inference
+ # response:
+ #
+ # pb_utils.InferenceResponse(
+ # output_tensors=..., TritonError("An error occurred"))
+ inference_response = pb_utils.InferenceResponse(
+ output_tensors=[output_tensor]
+ )
+ responses.append(inference_response)
+
+ # You should return a list of pb_utils.InferenceResponse. Length
+ # of this list must match the length of `requests` list.
+ return responses
+
+ def finalize(self):
+ """`finalize` is called only once when the model is being unloaded.
+        Implementing the `finalize` function is optional. This function allows
+ the model to perform any necessary clean ups before exit.
+ """
+ pb_utils.Logger.log("Finalizing the Post-Processing Model.")
+
+ def _id_to_token(self, token_id):
+ # handle special tokens (end of string, unknown, etc)
+ try:
+ special_token_index = self.tokenizer.all_special_ids.index(token_id)
+ return self.tokenizer.all_special_tokens[special_token_index]
+ except ValueError:
+ pass
+
+ # handle typical tokens
+ tokens = self.tokenizer.convert_ids_to_tokens(token_id)
+ if ord(tokens[0]) == SPACE_CHAR:
+ return f" {tokens[1:]}"
+ if ord(tokens[0]) == NEWLINE_CHAR:
+ return "\n"
+ return tokens
+
+ def _postprocessing(self, tokens_batch):
+ tokens_batch = tokens_batch.tolist()
+ return [
+ self._id_to_token(token_id)
+ for beam_tokens in tokens_batch
+ for token_ids in beam_tokens
+ for token_id in token_ids
+ ]
+
+ # for beam_tokens in tokens_batch:
+ # for token_ids in beam_tokens:
+ # for token_id in token_ids:
+ # # handle special tokens (end of string, unknown, etc)
+ # special_token = self.tokenizer.added_tokens_decoder.get(token_id)
+ # if special_token:
+ # tokens = special_token.content
+
+ # # handle typical tokens
+ # else:
+ # tokens = self.tokenizer.convert_ids_to_tokens(token_id)
+ # if ord(tokens[0]) == SPACE_CHAR:
+ # tokens = f" {tokens[1:]}"
+ # elif ord(tokens[0]) == NEWLINE_CHAR:
+ # tokens = "\n"
+
+ # outputs.append(tokens)
+ # return outputs
diff --git a/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/postprocessing/config.pbtxt b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/postprocessing/config.pbtxt
new file mode 100755
index 00000000..3c3ea10d
--- /dev/null
+++ b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/postprocessing/config.pbtxt
@@ -0,0 +1,50 @@
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+# * Neither the name of NVIDIA CORPORATION nor the names of its
+# contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+name: "postprocessing"
+backend: "python"
+max_batch_size: 128
+input [
+ {
+ name: "TOKENS_BATCH"
+ data_type: TYPE_INT32
+ dims: [ -1, -1 ]
+ }
+]
+output [
+ {
+ name: "OUTPUT"
+ data_type: TYPE_STRING
+ dims: [ -1, -1 ]
+ }
+]
+
+instance_group [
+ {
+ count: 1
+ kind: KIND_CPU
+ }
+]
diff --git a/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/preprocessing/1/model.py b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/preprocessing/1/model.py
new file mode 100644
index 00000000..44e8b9c4
--- /dev/null
+++ b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/preprocessing/1/model.py
@@ -0,0 +1,244 @@
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+# * Neither the name of NVIDIA CORPORATION nor the names of its
+# contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+
+import csv
+import json
+import os
+
+import numpy as np
+import torch
+import triton_python_backend_utils as pb_utils
+from torch.nn.utils.rnn import pad_sequence
+from transformers import LlamaTokenizer
+
+TOKENIZER_DIR = os.environ.get("TOKENIZER_DIR", "/model")
+
+END_ID = 2
+
+# SYSTEM_PROMPT = (
+# """You are a helpful, respectful and honest assistant."""
+# """Always answer as helpfully as possible, while being safe."""
+# """Please ensure that your responses are positive in nature."""
+# )
+
+# LLAMA_PROMPT_TEMPLATE = (
+# "[INST] <>"
+# "{system_prompt}"
+# "<>"
+# "[/INST] {context} [INST] {question} [/INST]"
+# )
+
+
+class TritonPythonModel:
+ """Your Python model must use the same class name. Every Python model
+ that is created must have "TritonPythonModel" as the class name.
+ """
+
+ def initialize(self, args):
+ """`initialize` is called only once when the model is being loaded.
+ Implementing `initialize` function is optional. This function allows
+ the model to initialize any state associated with this model.
+ Parameters
+ ----------
+ args : dict
+ Both keys and values are strings. The dictionary keys and values are:
+ * model_config: A JSON string containing the model configuration
+ * model_instance_kind: A string containing model instance kind
+ * model_instance_device_id: A string containing model instance device ID
+ * model_repository: Model repository path
+ * model_version: Model version
+ * model_name: Model name
+ """
+ # Parse model configs
+ self.model_config = model_config = json.loads(args["model_config"])
+
+ # Parse model output configs and convert Triton types to numpy types
+ input_names = ["INPUT_ID", "REQUEST_INPUT_LEN"]
+ for input_name in input_names:
+ setattr(
+ self,
+ input_name.lower() + "_dtype",
+ pb_utils.triton_string_to_numpy(
+ pb_utils.get_output_config_by_name(model_config, input_name)[
+ "data_type"
+ ]
+ ),
+ )
+
+ self.encoder = LlamaTokenizer.from_pretrained(TOKENIZER_DIR, legacy=False)
+
+ def execute(self, requests):
+ """`execute` must be implemented in every Python model. `execute`
+ function receives a list of pb_utils.InferenceRequest as the only
+ argument. This function is called when an inference is requested
+ for this model. Depending on the batching configuration (e.g. Dynamic
+ Batching) used, `requests` may contain multiple requests. Every
+        Python model must create one pb_utils.InferenceResponse for every
+ pb_utils.InferenceRequest in `requests`. If there is an error, you can
+ set the error argument when creating a pb_utils.InferenceResponse.
+ Parameters
+ ----------
+ requests : list
+ A list of pb_utils.InferenceRequest
+ Returns
+ -------
+ list
+ A list of pb_utils.InferenceResponse. The length of this list must
+ be the same as `requests`
+ """
+
+ responses = []
+
+        # Every Python backend must iterate over every one of the requests
+ # and create a pb_utils.InferenceResponse for each of them.
+ for request in requests:
+ # Get input tensors
+ query = pb_utils.get_input_tensor_by_name(request, "QUERY").as_numpy()
+ request_output_len = pb_utils.get_input_tensor_by_name(
+ request, "REQUEST_OUTPUT_LEN"
+ ).as_numpy()
+
+ input_id, request_input_len = self._create_request(query)
+
+ # Create output tensors. You need pb_utils.Tensor
+ # objects to create pb_utils.InferenceResponse.
+ input_id_tensor = pb_utils.Tensor(
+ "INPUT_ID", np.array(input_id).astype(self.input_id_dtype)
+ )
+ request_input_len_tensor = pb_utils.Tensor(
+ "REQUEST_INPUT_LEN",
+ np.array(request_input_len).astype(self.request_input_len_dtype),
+ )
+ request_output_len_tensor = pb_utils.Tensor(
+ "REQUEST_OUTPUT_LEN", request_output_len
+ )
+
+ # Create InferenceResponse. You can set an error here in case
+ # there was a problem with handling this inference request.
+ # Below is an example of how you can set errors in inference
+ # response:
+ #
+ # pb_utils.InferenceResponse(
+ # output_tensors=..., TritonError("An error occurred"))
+ inference_response = pb_utils.InferenceResponse(
+ output_tensors=[
+ input_id_tensor,
+ request_input_len_tensor,
+ request_output_len_tensor,
+ ]
+ )
+ responses.append(inference_response)
+
+ # You should return a list of pb_utils.InferenceResponse. Length
+ # of this list must match the length of `requests` list.
+ return responses
+
+ def finalize(self):
+ """`finalize` is called only once when the model is being unloaded.
+ Implementing `finalize` function is optional. This function allows
+ the model to perform any necessary clean ups before exit.
+ """
+ pb_utils.Logger.log("Finalizing the Pre-Processing Model.")
+
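+    # Prompts are tokenized one by one; the batch is padded to the longest
+    # sequence with END_ID, while the unpadded per-sequence lengths are returned
+    # separately so the padding can be ignored downstream.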
+ def _create_request(self, prompts):
+ """
+        prompts : batch of UTF-8 encoded prompt strings (2D numpy array of bytes objects)
+ """
+
+ start_ids = [
+ torch.IntTensor(self.encoder.encode(prompt[0].decode()))
+ for prompt in prompts
+ ]
+
+ start_lengths = torch.IntTensor([[len(ids)] for ids in start_ids])
+
+ start_ids = pad_sequence(start_ids, batch_first=True, padding_value=END_ID)
+
+ return start_ids, start_lengths
+
+ def _create_word_list(self, word_dict):
+ flat_ids = []
+ offsets = []
+ for word_dict_item in word_dict:
+ item_flat_ids = []
+ item_offsets = []
+
+ words = list(csv.reader([word_dict_item[0].decode()]))[0]
+ for word in words:
+ ids = self._encode(word)
+
+ if len(ids) == 0:
+ continue
+
+ item_flat_ids += ids
+ item_offsets.append(len(ids))
+
+ flat_ids.append(np.array(item_flat_ids))
+ offsets.append(np.cumsum(np.array(item_offsets)))
+
+ pad_to = max(1, max(len(ids) for ids in flat_ids))
+
+ for i, (ids, offs) in enumerate(zip(flat_ids, offsets)):
+ flat_ids[i] = np.pad(ids, (0, pad_to - len(ids)), constant_values=0)
+ offsets[i] = np.pad(offs, (0, pad_to - len(offs)), constant_values=-1)
+
+ return np.array([flat_ids, offsets], dtype="int32").transpose((1, 0, 2))
+
+ def to_word_list_format(self, word_dict):
+ flat_ids = []
+ offsets = []
+ for word_dict_item in word_dict:
+ item_flat_ids = []
+ item_offsets = []
+
+ if isinstance(word_dict_item[0], bytes):
+ word_dict_item = [word_dict_item[0].decode()]
+
+ words = list(csv.reader(word_dict_item))[0]
+ for word in words:
+ ids = self.encoder.encode(word)
+
+ if len(ids) == 0:
+ continue
+
+ item_flat_ids += ids
+ item_offsets.append(len(ids))
+
+ flat_ids.append(np.array(item_flat_ids))
+ offsets.append(np.cumsum(np.array(item_offsets)))
+
+ pad_to = max(1, max(len(ids) for ids in flat_ids))
+
+ for i, (ids, offs) in enumerate(zip(flat_ids, offsets)):
+ flat_ids[i] = np.pad(ids, (0, pad_to - len(ids)), constant_values=0)
+ offsets[i] = np.pad(offs, (0, pad_to - len(offs)), constant_values=-1)
+
+ return np.array([flat_ids, offsets], dtype="int32").transpose((1, 0, 2))
+
+ def _encode(self, sentence):
+ sentence = sentence.decode() if isinstance(sentence, bytes) else sentence
+ return self.encoder.encode(sentence)
diff --git a/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/preprocessing/config.pbtxt b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/preprocessing/config.pbtxt
new file mode 100644
index 00000000..d2e3029a
--- /dev/null
+++ b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/preprocessing/config.pbtxt
@@ -0,0 +1,65 @@
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+# * Neither the name of NVIDIA CORPORATION nor the names of its
+# contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+name: "preprocessing"
+backend: "python"
+max_batch_size: 128
+input [
+ {
+ name: "QUERY"
+ data_type: TYPE_STRING
+ dims: [ -1 ]
+ },
+ {
+ name: "REQUEST_OUTPUT_LEN"
+ data_type: TYPE_UINT32
+ dims: [ -1 ]
+ }
+]
+output [
+ {
+ name: "INPUT_ID"
+ data_type: TYPE_INT32
+ dims: [ -1 ]
+ },
+ {
+ name: "REQUEST_INPUT_LEN"
+ data_type: TYPE_INT32
+ dims: [ 1 ]
+ },
+ {
+ name: "REQUEST_OUTPUT_LEN"
+ data_type: TYPE_UINT32
+ dims: [ -1 ]
+ }
+]
+
+instance_group [
+ {
+ count: 1
+ kind: KIND_CPU
+ }
+]
diff --git a/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/tensorrt_llm/1/.gitkeep b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/tensorrt_llm/1/.gitkeep
new file mode 100644
index 00000000..e69de29b
diff --git a/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/tensorrt_llm/config.pbtxt.j2 b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/tensorrt_llm/config.pbtxt.j2
new file mode 100644
index 00000000..4b719b04
--- /dev/null
+++ b/RetrievalAugmentedGeneration/llm-inference-server/ensemble_models/llama/tensorrt_llm/config.pbtxt.j2
@@ -0,0 +1,208 @@
+# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+# * Neither the name of NVIDIA CORPORATION nor the names of its
+# contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+name: "tensorrt_llm"
+backend: "tensorrtllm"
+max_batch_size: 128
+
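+# This file is a Jinja2 template; the decoupled_mode, gpt_model_type and
+# engine_dir placeholders are presumably rendered by the model server when it
+# prepares the Triton model repository. Setting decoupled mode to true enables
+# streaming (one response per generated token).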
+model_transaction_policy {
+ decoupled: {{ decoupled_mode }}
+}
+
+input [
+ {
+ name: "input_ids"
+ data_type: TYPE_INT32
+ dims: [ -1 ]
+ },
+ {
+ name: "input_lengths"
+ data_type: TYPE_INT32
+ dims: [ 1 ]
+ reshape: { shape: [ ] }
+ },
+ {
+ name: "request_output_len"
+ data_type: TYPE_UINT32
+ dims: [ 1 ]
+ },
+ {
+ name: "end_id"
+ data_type: TYPE_UINT32
+ dims: [ 1 ]
+ reshape: { shape: [ ] }
+ optional: true
+ },
+ {
+ name: "pad_id"
+ data_type: TYPE_UINT32
+ dims: [ 1 ]
+ reshape: { shape: [ ] }
+ optional: true
+ },
+ {
+ name: "beam_width"
+ data_type: TYPE_UINT32
+ dims: [ 1 ]
+ reshape: { shape: [ ] }
+ optional: true
+ },
+ {
+ name: "temperature"
+ data_type: TYPE_FP32
+ dims: [ 1 ]
+ reshape: { shape: [ ] }
+ optional: true
+ },
+ {
+ name: "runtime_top_k"
+ data_type: TYPE_UINT32
+ dims: [ 1 ]
+ reshape: { shape: [ ] }
+ optional: true
+ },
+ {
+ name: "runtime_top_p"
+ data_type: TYPE_FP32
+ dims: [ 1 ]
+ reshape: { shape: [ ] }
+ optional: true
+ },
+ {
+ name: "len_penalty"
+ data_type: TYPE_FP32
+ dims: [ 1 ]
+ reshape: { shape: [ ] }
+ optional: true
+ },
+ {
+ name: "repetition_penalty"
+ data_type: TYPE_FP32
+ dims: [ 1 ]
+ reshape: { shape: [ ] }
+ optional: true
+ },
+ {
+ name: "min_length"
+ data_type: TYPE_UINT32
+ dims: [ 1 ]
+ reshape: { shape: [ ] }
+ optional: true
+ },
+ {
+ name: "presence_penalty"
+ data_type: TYPE_FP32
+ dims: [ 1 ]
+ reshape: { shape: [ ] }
+ optional: true
+ },
+ {
+ name: "random_seed"
+ data_type: TYPE_UINT64
+ dims: [ 1 ]
+ reshape: { shape: [ ] }
+ optional: true
+ },
+ {
+ name: "stop"
+ data_type: TYPE_BOOL
+ dims: [ 1 ]
+ optional: true
+ },
+ {
+ name: "streaming"
+ data_type: TYPE_BOOL
+ dims: [ 1 ]
+ optional: true
+ }
+]
+output [
+ {
+ name: "output_ids"
+ data_type: TYPE_INT32
+ dims: [ -1, -1 ]
+ }
+]
+instance_group [
+ {
+ count: 1
+ kind: KIND_CPU
+ }
+]
+parameters: {
+ key: "max_beam_width"
+ value: {
+ string_value: "1"
+ }
+}
+parameters: {
+ key: "FORCE_CPU_ONLY_INPUT_TENSORS"
+ value: {
+ string_value: "no"
+ }
+}
+parameters: {
+ key: "gpt_model_type"
+ value: {
+ string_value: "{{ gpt_model_type }}"
+ }
+}
+parameters: {
+ key: "gpt_model_path"
+ value: {
+ string_value: "{{ engine_dir }}"
+ }
+}
+parameters: {
+ key: "max_tokens_in_paged_kv_cache"
+ value: {
+ string_value: ""
+ }
+}
+parameters: {
+ key: "batch_scheduler_policy"
+ value: {
+ string_value: "guaranteed_completion"
+ }
+}
+parameters: {
+ key: "kv_cache_free_gpu_mem_fraction"
+ value: {
+ string_value: ".75"
+ }
+}
+parameters: {
+ key: "max_num_sequences"
+ value: {
+ string_value: ""
+ }
+}
+parameters: {
+ key: "enable_trt_overlap"
+ value: {
+ string_value: ""
+ }
+}
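The three template variables above (`decoupled_mode`, `gpt_model_type`, `engine_dir`) are filled in by the model server at startup (see `model_server/server.py` further down). A minimal rendering sketch with illustrative values, assuming the repository layout shown in this change:

```python
# Sketch only: render the Triton backend config from the Jinja template.
# The values below are illustrative; the server derives them from the model.
from jinja2 import Environment, FileSystemLoader

env = Environment(
    loader=FileSystemLoader("ensemble_models/llama"),  # assumed working directory
    autoescape=False,
)
template = env.get_template("tensorrt_llm/config.pbtxt.j2")
rendered = template.render(
    decoupled_mode="true",                     # "true" enables gRPC token streaming
    gpt_model_type="inflight_fused_batching",  # or "V1" for NeMo-converted engines
    engine_dir="/model/trt-w1-cc8.6",          # hypothetical engine cache directory
)
with open("ensemble_models/llama/tensorrt_llm/config.pbtxt", "w", encoding="UTF-8") as out:
    out.write(rendered)
```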
diff --git a/RetrievalAugmentedGeneration/llm-inference-server/model_server/__init__.py b/RetrievalAugmentedGeneration/llm-inference-server/model_server/__init__.py
new file mode 100644
index 00000000..c5141a38
--- /dev/null
+++ b/RetrievalAugmentedGeneration/llm-inference-server/model_server/__init__.py
@@ -0,0 +1,129 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Model-Server converts LLMs to TensorRT engines and hosts them with Triton."""
+import argparse
+import logging
+import os
+
+from .conversion import ConversionOptions, convert
+from .errors import ModelServerException
+from .model import Model
+from .server import ModelServer
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def _mk_symlink(source: str, dest: str) -> None:
+ """Safely create a symbolic link."""
+ if not os.path.exists(dest):
+ os.symlink(source, dest)
+ _LOGGER.debug("Creating symlink from %s to %s", source, dest)
+
+
+def _azureml(source_directory: str) -> None:
+ """Make accomodations for AzureML."""
+ _LOGGER.info("Detected running on AzureML.")
+ _LOGGER.debug("AZUREML_MODEL_DIR Variable is %s", source_directory)
+ destination_directory = "/model"
+ source_directory = os.path.abspath(source_directory)
+ _LOGGER.debug("Azure Model Directory is now %s", source_directory)
+ _LOGGER.debug(
+ "Azure model directory contents: %s", repr(os.listdir(source_directory))
+ )
+
+ # find the direct path to the model weights
+ # $AZUREML_MODEL_DIR/model_name
+ try:
+ source_directory = os.path.join(
+ source_directory, os.listdir(source_directory)[0]
+ )
+ _LOGGER.debug("Azure Model Directory is now %s", source_directory)
+ _LOGGER.debug(
+ "Azure model directory contents: %s", repr(os.listdir(source_directory))
+ )
+ except IndexError:
+ # pylint: disable-next=raise-missing-from
+ raise ModelServerException("AzureML folder structure is not recognized.")
+
+ # create links for the model files to the /models directory
+ for root, dirs, files in os.walk(source_directory):
+ root_destination = root.replace(source_directory, destination_directory)
+ for fi in files:
+ _mk_symlink(os.path.join(root, fi), os.path.join(root_destination, fi))
+ for di in dirs:
+ dest = os.path.join(root_destination, di)
+ os.makedirs(dest, exist_ok=True)
+ _LOGGER.debug("Creating directory %s", dest)
+
+
+def _should_convert(args: argparse.Namespace, model: "Model") -> bool:
+ """Determine if the conversion step should run."""
+ if args.force_conversion:
+ return True
+
+ if args.no_conversion:
+ return False
+
+ return model.conversion_is_needed()
+
+
+def main(args: argparse.Namespace) -> int:
+ """Execute the model server."""
+ # make accommodations for various ML platforms
+ azureml_model_dir = os.environ.get("AZUREML_MODEL_DIR")
+ if azureml_model_dir:
+ _azureml(azureml_model_dir)
+
+ # load the model directory
+ _LOGGER.info("Reading the model directory.")
+ model = Model(model_type=args.type, world_size=args.world_size)
+
+ # calculate the default parallelism parameters
+ if not args.tensor_parallelism:
+ args.tensor_parallelism = max(
+ int(model.world_size / args.pipeline_parallelism), 1
+ )
+ if args.pipeline_parallelism * args.tensor_parallelism != model.world_size:
+ raise ModelServerException(
+ "Tensor Parallelism * Pipeline Parallelism must be equal to World Size"
+ )
+
+ conversion_opts = ConversionOptions(
+ max_input_length=args.max_input_length,
+ max_output_length=args.max_output_length,
+ tensor_parallelism=args.tensor_parallelism,
+ pipeline_parallelism=args.pipeline_parallelism,
+ )
+
+ # print discovered model parameters
+ _LOGGER.info("Model file format: %s", model.format.name)
+ _LOGGER.info("World Size: %d", model.world_size)
+ _LOGGER.info("Compute Capability: %s", model.compute_cap)
+
+ # convert model
+ if _should_convert(args, model):
+ _LOGGER.info("Starting TensorRT Conversion.")
+ convert(model, conversion_opts)
+ else:
+ _LOGGER.info("TensorRT Conversion not required. Skipping.")
+
+ # host model
+ if not args.no_hosting:
+ _LOGGER.info("Starting Triton Inference Server.")
+ inference_server = ModelServer(model, args.http)
+ return inference_server.run()
+
+ return 0
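A quick worked example of the parallelism defaulting in `main()` above, with a hypothetical GPU count:

```python
# Illustrative only: with 4 GPUs and --pipeline-parallelism 2, tensor
# parallelism defaults to 2, and 2 * 2 == 4 satisfies the world-size check.
world_size = 4            # hypothetical number of attached GPUs
pipeline_parallelism = 2
tensor_parallelism = max(int(world_size / pipeline_parallelism), 1)
assert tensor_parallelism == 2
assert tensor_parallelism * pipeline_parallelism == world_size
```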
diff --git a/RetrievalAugmentedGeneration/llm-inference-server/model_server/__main__.py b/RetrievalAugmentedGeneration/llm-inference-server/model_server/__main__.py
new file mode 100644
index 00000000..b64791d7
--- /dev/null
+++ b/RetrievalAugmentedGeneration/llm-inference-server/model_server/__main__.py
@@ -0,0 +1,196 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Main entrypoint for the model-server application."""
+import argparse
+import logging
+import os
+import sys
+
+from . import main
+from .errors import ModelServerException
+from .model import ModelTypes
+
+TERMINATION_LOG = "/dev/termination-log"
+
+_LOG_FMT = f"[{os.getpid()}] %(asctime)15s [%(levelname)7s] - %(name)s - %(message)s"
+_LOG_DATE_FMT = "%b %d %H:%M:%S"
+_LOGGER = logging.getLogger("main")
+
+
+def parse_args() -> argparse.Namespace:
+ """Parse the comamnd line arguments."""
+ parser = argparse.ArgumentParser(
+ prog="model-server",
+ description="Ingest models and host them with NVIDIA TensorRT LLM",
+ )
+
+ # options
+ parser.add_argument(
+ "-w",
+ "--world-size",
+ default=None,
+ type=int,
+ help="The number of GPUs to shard the model across. "
+ + "By default, this value will be equal to the number of available GPUs.",
+ )
+ parser.add_argument(
+ "--force-conversion",
+ action="store_true",
+ help="When this flag is set, the TensorRT engine conversion will occur, "
+ + "even if a valid engine is in the cache.",
+ )
+ parser.add_argument(
+ "--no-conversion",
+ action="store_true",
+ help="Skip the conversion. If no engine is available in the cache, an error will be raised.",
+ )
+ parser.add_argument(
+ "--no-hosting",
+ action="store_true",
+ help="Do not start the Triton Inference Server. Only convert the model then exit.",
+ )
+ parser.add_argument(
+ "-v",
+ "--verbose",
+ action="count",
+ default=1,
+ help="increase output verbosity",
+ )
+ parser.add_argument(
+ "-q",
+ "--quiet",
+ action="count",
+ default=0,
+ help="decrease output verbosity",
+ )
+
+ # builder customization
+ parser.add_argument(
+ "--max-input-length",
+ type=int,
+ default=3000,
+ help="maximum number of input tokens",
+ )
+ parser.add_argument(
+ "--max-output-length",
+ type=int,
+ default=512,
+ help="maximum number of output tokens",
+ )
+ parser.add_argument(
+ "--tensor-parallelism",
+ type=int,
+ default=None,
+ help="number of tensor parallelism divisions (default: world_size/pipeline_parallelism)",
+ )
+ parser.add_argument(
+ "--pipeline-parallelism",
+ type=int,
+ default=1,
+ help="number of pipeline parallism divisions (default: 1)",
+ )
+
+ # server customization
+ parser.add_argument(
+ "--http",
+ action="store_true",
+ help="change the api server to http instead of grpc (note: this will disable token streaming)",
+ )
+
+ # positional arguments
+ supported_model_types = [e.name.lower().replace("_", "-") for e in ModelTypes]
+ parser.add_argument(
+ "type",
+ metavar="TYPE",
+ choices=supported_model_types,
+ type=str.lower,
+ help=f"{supported_model_types} The type of model to process.",
+ )
+
+ args = parser.parse_args()
+
+ if args.force_conversion and args.no_conversion:
+ parser.error("--force_conversion and --no-conversion are mutually exclusive.")
+
+ return args
+
+
+def _bootstrap_logging(verbosity: int = 0) -> None:
+ """Configure Python's logger according to the given verbosity level.
+
+ :param verbosity: The desired verbosity level. Must be one of 0, 1, or 2.
+ :type verbosity: typing.Literal[0, 1, 2]
+ """
+ # determine log level
+ verbosity = min(2, max(0, verbosity)) # limit verbosity to 0-2
+ log_level = [logging.WARN, logging.INFO, logging.DEBUG][verbosity]
+
+ # configure python's logger
+ logging.basicConfig(format=_LOG_FMT, datefmt=_LOG_DATE_FMT, level=log_level)
+ # update existing loggers
+ _LOGGER.setLevel(log_level)
+ # pylint: disable-next=no-member; false positive
+ for logger_name in logging.root.manager.loggerDict:
+ logger = logging.getLogger(logger_name)
+ for handler in logger.handlers:
+ handler.setFormatter(logging.Formatter(fmt=_LOG_FMT, datefmt=_LOG_DATE_FMT))
+
+
+def _k8s_error_handler(err: Exception) -> None:
+ """When running in Kubernetes, write errors to the termination log."""
+ with open(TERMINATION_LOG, "w", encoding="UTF-8") as term_log:
+ # recursively write nested exceptions
+ def _write_errors_to_term_log(e: BaseException) -> None:
+ term_log.write(f"{type(e)}: {e}\n")
+ if e.__cause__:
+ _write_errors_to_term_log(e.__cause__)
+
+ _write_errors_to_term_log(err)
+
+
+def _error_handler(err: Exception) -> int:
+ """Catch and handle exceptions from the applicaiton."""
+ # keybaord interrupts are fine
+ if isinstance(err, KeyboardInterrupt):
+ return 0
+
+ # on k8s, write errors to log file
+ if os.path.isfile(TERMINATION_LOG):
+ _k8s_error_handler(err)
+
+ # raise uncaught errors
+ if not isinstance(err, ModelServerException):
+ raise err
+
+ # gracefully handle caught errors
+ _LOGGER.error(str(err))
+
+ # if there is a nested error, raise it
+ if err.__cause__:
+ raise err.__cause__
+
+ # we decided to quit gracefully
+ return 1
+
+
+if __name__ == "__main__":
+ try:
+ _ARGS = parse_args()
+ _bootstrap_logging(_ARGS.verbose - _ARGS.quiet)
+ sys.exit(main(_ARGS))
+ # pylint: disable-next=broad-exception-caught; Error handling based on type is done in the handler
+ except Exception as _ERR:
+ sys.exit(_error_handler(_ERR))
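A hedged smoke test of the CLI surface defined above; it assumes the `model_server` package and its requirements are importable, and only exercises argument parsing (no GPUs needed):

```python
# Sketch only: verify the argument parser accepts a typical invocation.
import sys
from model_server.__main__ import parse_args  # assumed import path

sys.argv = ["model-server", "--no-hosting", "--max-input-length", "3000", "llama"]
args = parse_args()
assert args.type == "llama"
assert args.no_hosting and not args.force_conversion
```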
diff --git a/RetrievalAugmentedGeneration/llm-inference-server/model_server/conversion/__init__.py b/RetrievalAugmentedGeneration/llm-inference-server/model_server/conversion/__init__.py
new file mode 100644
index 00000000..e32153fa
--- /dev/null
+++ b/RetrievalAugmentedGeneration/llm-inference-server/model_server/conversion/__init__.py
@@ -0,0 +1,73 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This module contains the logic for doing model conversions to TensorRT."""
+from dataclasses import dataclass
+from typing import Optional
+
+from ..errors import ModelServerException
+from ..model import Model, ModelFormats, ModelTypes
+
+
+@dataclass
+class ConversionOptions:
+ """Class containing the options used in TRT conversion."""
+
+ max_input_length: int
+ max_output_length: int
+ pipeline_parallelism: int
+ tensor_parallelism: int
+ vocab_size: Optional[int] = None
+
+
+def convert(model: Model, opts: ConversionOptions) -> None:
+ """
+ Convert the provided model to TensorRT.
+
+ Supported types and formats:
+ +----------+---------+---------+---------+---------+---------+
+ | | NEMO | PYTORCH | ONNX | HFACE | UNKNOWN |
+ +----------+---------+---------+---------+---------+---------+
+ | LLAMA | ✅ | ✅ | ❌ | ✅ | ❌ |
+ | GPTNEXT | ✅ | ❌ | ❌ | ❌ | ❌ |
+ +----------+---------+---------+---------+---------+---------+
+ """
+ if model.format == ModelFormats.NEMO:
+ # pylint: disable-next=import-outside-toplevel # preventing circular imports
+ from . import nemo
+
+ nemo.convert(model, opts)
+
+ elif model.type == ModelTypes.LLAMA:
+ # pylint: disable-next=import-outside-toplevel # preventing circular imports
+ from . import llama
+
+ opts.vocab_size = 32000
+ llama.convert(model, opts)
+
+ elif model.type == ModelTypes.CODE_LLAMA:
+ # pylint: disable-next=import-outside-toplevel # preventing circular imports
+ from . import llama
+
+ opts.vocab_size = 32016
+ llama.convert(model, opts)
+
+ else:
+ supported_types = [e.name for e in ModelTypes]
+ raise ModelServerException(
+ f"Unsupported model type. Conversion is supported for the following types: {supported_types}"
+ )
+
+ model.write_hash()
diff --git a/RetrievalAugmentedGeneration/llm-inference-server/model_server/conversion/llama.py b/RetrievalAugmentedGeneration/llm-inference-server/model_server/conversion/llama.py
new file mode 100644
index 00000000..5dc21552
--- /dev/null
+++ b/RetrievalAugmentedGeneration/llm-inference-server/model_server/conversion/llama.py
@@ -0,0 +1,96 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This module contains the logic for exporting a Llama model in PyTorch format to TensorRT."""
+import logging
+import os
+import subprocess
+import sys
+import typing
+
+from ..errors import ModelServerException, UnsupportedFormatException
+from ..model import Model
+from . import ConversionOptions
+
+_CONVERSION_SCRIPTS = "/opt/conversion_scripts/llama"
+
+_CHECKPOINT_ARGS_FLAGS = {"PYTORCH": "--meta_ckpt_dir", "HUGGINGFACE": "--model_dir"}
+_LOGGER = logging.getLogger(__name__)
+
+
+def convert(model: Model, opts: ConversionOptions) -> None:
+ """Convert a llama model."""
+ _LOGGER.debug("Running Llama model conversion.")
+
+ # construct builder executable path
+ cwd = _CONVERSION_SCRIPTS
+ exe = [sys.executable, "build.py"]
+
+ # construct builder env variables
+ env = os.environ
+
+ # construct builder arguments
+ try:
+ raw_args: typing.List[str] = [
+ "--max_input_len",
+ str(opts.max_input_length),
+ "--max_output_len",
+ str(opts.max_output_length),
+ "--dtype",
+ "float16",
+ "--use_gpt_attention_plugin",
+ "float16",
+ "--use_inflight_batching",
+ "--paged_kv_cache",
+ "--remove_input_padding",
+ "--use_gemm_plugin",
+ "float16",
+ "--output_dir",
+ model.engine_dir,
+ "--world_size",
+ str(model.world_size),
+ "--tp_size",
+ str(opts.tensor_parallelism),
+ "--pp_size",
+ str(opts.pipeline_parallelism),
+ "--vocab_size",
+ str(opts.vocab_size),
+ _CHECKPOINT_ARGS_FLAGS[model.format.name],
+ model.model_dir,
+ ]
+ except KeyError as err:
+ raise UnsupportedFormatException(
+ model.format.name, ["PyTorch", "Hugging Face"]
+ ) from err
+
+ # start the builder
+ _LOGGER.debug(
+ "Starting Llama exporter with the command: %s", " ".join(exe + raw_args)
+ )
+ _LOGGER.debug("Starting Llama exporter with the env vars: %s", repr(env))
+ with subprocess.Popen(exe + raw_args, env=env, cwd=cwd) as proc:
+ try:
+ retcode = proc.wait()
+ except KeyboardInterrupt:
+ proc.kill()
+ except Exception as err:
+ raise ModelServerException(
+ "Error running TensorRT model conversion."
+ ) from err
+ else:
+ if retcode != 0:
+ raise ModelServerException(
+ "TensorRT conversion returned a non-zero exit code."
+ )
diff --git a/RetrievalAugmentedGeneration/llm-inference-server/model_server/conversion/nemo.py b/RetrievalAugmentedGeneration/llm-inference-server/model_server/conversion/nemo.py
new file mode 100644
index 00000000..d0eb10f3
--- /dev/null
+++ b/RetrievalAugmentedGeneration/llm-inference-server/model_server/conversion/nemo.py
@@ -0,0 +1,71 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This module contains the code for converting any .nemo formatted model to TRT LLM."""
+import logging
+import os
+from glob import glob
+from tarfile import TarFile
+from typing import IO, cast
+
+import yaml
+
+# pylint: disable-next=import-error
+from nemo.export import TensorRTLLM # type: ignore
+
+from ..errors import ModelServerException
+from ..model import Model
+from . import ConversionOptions
+
+_LOGGER = logging.getLogger(__name__)
+
+
+def convert(model: Model, _: ConversionOptions) -> None:
+ """Convert a .nemo formatted model."""
+ # find the .nemo model file
+ model_files = glob(os.path.join(model.model_dir, "*.nemo"))
+ if len(model_files) > 1:
+ raise ModelServerException(
+ "More than one NeMo checkpoint found in the model directory. "
+ + "Please only include one NeMo checkpoint file."
+ )
+
+ # verify that the model parallelism matches the requested world size
+ config = {}
+ with TarFile(model_files[0], "r") as archive:
+ try:
+ config_file = cast(IO[bytes], archive.extractfile("./model_config.yaml"))
+ except KeyError:
+ config_file = cast(IO[bytes], archive.extractfile("model_config.yaml"))
+ config = yaml.safe_load(config_file)
+ config_file.close()
+
+ if config.get("tensor_model_parallel_size", 1) != model.world_size:
+ raise ModelServerException(
+ f"The provided model has a tensor parallelism of {config.get('tensor_model_parallel_size', 1)} "
+ + f"and the server has been requested to use {model.world_size} "
+ + "gpus. Please use the NeMo inference container to rezise the parallelism of the model or change "
+ + "the model-server's world size."
+ )
+
+ # run the nemo to trt llm conversion
+ trt_llm_exporter = TensorRTLLM(model_dir=model.engine_dir)
+ _LOGGER.info(".nemo to TensorRT Conversion started. This will take a few minutes.")
+ _LOGGER.info(model.engine_dir)
+ trt_llm_exporter.export(
+ nemo_checkpoint_path=model_files[0],
+ model_type=model.family,
+ n_gpus=model.world_size,
+ )
diff --git a/RetrievalAugmentedGeneration/llm-inference-server/model_server/errors.py b/RetrievalAugmentedGeneration/llm-inference-server/model_server/errors.py
new file mode 100644
index 00000000..5609a58a
--- /dev/null
+++ b/RetrievalAugmentedGeneration/llm-inference-server/model_server/errors.py
@@ -0,0 +1,32 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""The custom errors raised by the model server."""
+import typing
+
+
+class ModelServerException(Exception):
+ """The base class for any custom expections."""
+
+
+class UnsupportedFormatException(ModelServerException):
+ """An error that indicates the model format is not supported for the provided type."""
+
+ def __init__(self, model_type: str, supported: typing.List[str]):
+ """Initialize the exception."""
+ super().__init__(
+ "Unsupported model type and format combination. "
+ + f"{model_type} models are supported in the following formats: {str(supported)}"
+ )
diff --git a/RetrievalAugmentedGeneration/llm-inference-server/model_server/model.py b/RetrievalAugmentedGeneration/llm-inference-server/model_server/model.py
new file mode 100644
index 00000000..0f83459b
--- /dev/null
+++ b/RetrievalAugmentedGeneration/llm-inference-server/model_server/model.py
@@ -0,0 +1,246 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This module contains the model class that represents the model mounted to the container."""
+import glob
+import hashlib
+import logging
+import os
+import pathlib
+import subprocess
+import typing
+from enum import Enum, auto, unique
+
+from .errors import ModelServerException
+
+DEFAULT_MODEL_DIR = "/model"
+HASH_COMMAND = "sha1sum"
+_LOGGER = logging.getLogger(__name__)
+
+
+def _fast_hash_dir(dir_path: str) -> str:
+ """
+ Read the files in a directory and quickly create a hash.
+
+ This hash IS NOT cryptographically secure, but it is designed to be computed as quickly as reasonably possible.
+ This function will only hash top level files and will not traverse directories.
+ """
+ # spawn a pool of subprocess workers to calculate individual hashes
+ workers = []
+ for obj in os.listdir(dir_path):
+ obj_path = os.path.join(dir_path, obj)
+ if not os.path.isfile(obj_path):
+ continue
+
+ workers += [
+ # pylint: disable-next=consider-using-with
+ subprocess.Popen(
+ [HASH_COMMAND, obj_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE
+ )
+ ]
+
+ # wait for workers to complete
+ all_shas = b""
+ for proc in workers:
+ stdout, _ = proc.communicate()
+ all_shas += stdout.split(b" ", maxsplit=1)[0]
+
+ hasher = hashlib.sha1(usedforsecurity=False)
+ hasher.update(all_shas)
+ return hasher.hexdigest()
+
+
+@unique
+class ModelFormats(Enum):
+ """A Enumerator containing all of the supported model types."""
+
+ UNKNOWN = auto()
+ ONNX = auto()
+ PYTORCH = auto()
+ HUGGINGFACE = auto()
+ NEMO = auto()
+
+
+@unique
+class ModelTypes(Enum):
+ """A enumerator of the supported model types."""
+
+ LLAMA = auto()
+ CODE_LLAMA = auto()
+ GPTNEXT = auto()
+
+ @property
+ def family(self) -> str:
+ """Return the family grouping of the model."""
+ return ["llama", "llama", "gptnext"][self.value - 1]
+
+
+class Model:
+ """A representation of the mounted model."""
+
+ def __init__(
+ self,
+ model_type: str,
+ model_dir: typing.Optional[str] = None,
+ world_size: typing.Optional[int] = None,
+ ):
+ """Initialize the model class."""
+ try:
+ self._type = ModelTypes[model_type.upper().replace("-", "_")]
+ except KeyError as err:
+ raise ModelServerException(f"Unrecognized model type {type}") from err
+
+ self._model_dir = model_dir or DEFAULT_MODEL_DIR
+ self._gpu_info = self._init_gpu_info(world_size=world_size)
+ self._hash: typing.Optional[str] = None
+ self._engine_dir = self._init_engine_dir()
+ self._format = self._init_model_format()
+
+ @classmethod
+ def _init_gpu_info(
+ cls,
+ world_size: typing.Optional[int] = None,
+ ) -> typing.Dict[str, typing.Union[str, int]]:
+ """
+ Query the compute capability and world size of the GPUs attached to the container.
+
+ Returns
+ -------
+ Dict: A dictionary containing the "compute_cap" and "world_size" keys.
+ """
+ query_cmd = ["nvidia-smi", "--query-gpu=compute_cap", "--format=csv,noheader"]
+ gpu_info_raw = subprocess.check_output(query_cmd)
+ compute_caps = [cap.decode() for cap in gpu_info_raw.strip().split(b"\n")]
+ # FUTURE: convert this to use nvml instead
+
+ # do basic error checking
+ if len(compute_caps) == 0:
+ raise ModelServerException("No GPUs attached to the container.")
+ if len(set(compute_caps)) > 1:
+ raise ModelServerException(
+ "Attached GPUs are dissimilar. All GPUs must be of the same type."
+ )
+ if not world_size:
+ world_size = len(compute_caps)
+
+ return {"compute_cap": compute_caps[0], "world_size": world_size}
+
+ def _init_engine_dir(self) -> str:
+ """Create and return the path to the TensorRT cache directory for this model."""
+ cache_dir = f"trt-w{self.world_size}-cc{self.compute_cap}"
+ cache_path = os.path.join(self.model_dir, cache_dir)
+ pathlib.Path(cache_path).mkdir(parents=True, exist_ok=True)
+ return cache_path
+
+ def _init_model_format(self) -> ModelFormats:
+ """Determine the format of model that has been mounted."""
+ # look for nemo checkpoints
+ nemo_count = self._file_ext_count("nemo")
+ if nemo_count == 1:
+ return ModelFormats.NEMO
+ if nemo_count > 1:
+ raise ModelServerException(
+ f"Only one nemo checkpoint file may be in the model directory. Found {nemo_count}",
+ )
+
+ # look for pytorch saved models
+ pytorch_count = self._file_ext_count("pth") + self._file_ext_count("pt")
+ if pytorch_count:
+ return ModelFormats.PYTORCH
+
+ # look for huggingface saved models
+ hf_count = self._file_ext_count("bin")
+ if hf_count:
+ return ModelFormats.HUGGINGFACE
+
+ # look for onnx models
+ onnx_count = self._file_ext_count("onnx")
+ if onnx_count:
+ return ModelFormats.ONNX
+
+ return ModelFormats.UNKNOWN
+
+ def _file_ext_count(self, extension: str) -> int:
+ """Count the files in a directory with a given extension."""
+ path = os.path.join(self.model_dir, f"*.{extension}")
+ return len(glob.glob(path))
+
+ @property
+ def type(self) -> ModelTypes:
+ """Return the type of the model."""
+ return self._type
+
+ @property
+ def family(self) -> str:
+ """Return the model family grouping."""
+ return self._type.family
+
+ @property
+ def model_dir(self) -> str:
+ """Return the stored model directory."""
+ return self._model_dir
+
+ @property
+ def engine_dir(self) -> str:
+ """Return the stored engine directory."""
+ return self._engine_dir
+
+ @property
+ def world_size(self) -> int:
+ """Return the world size."""
+ ws = self._gpu_info["world_size"]
+ return typing.cast(int, ws)
+
+ @property
+ def compute_cap(self) -> str:
+ """Return the compute capability version."""
+ cc = self._gpu_info["compute_cap"]
+ return typing.cast(str, cc)
+
+ @property
+ def format(self) -> ModelFormats:
+ """Return the format of the model."""
+ return self._format
+
+ @property
+ def hash(self) -> str:
+ """Return the hash of the model."""
+ if not self._hash:
+ _LOGGER.info("Calculating model hash.")
+ self._hash = _fast_hash_dir(self.model_dir)
+ return self._hash
+
+ @property
+ def _last_hash_path(self) -> str:
+ """Return the path to the last known hash file."""
+ return os.path.join(self.engine_dir, "hash")
+
+ def conversion_is_needed(self) -> bool:
+ """Determine if the engine conversion is required."""
+ if not os.path.isfile(self._last_hash_path):
+ _LOGGER.debug("No engine file exists. Will generate an engine file.")
+ return True
+ with open(self._last_hash_path, "r", encoding="ASCII") as hash_file:
+ last_hash = hash_file.read()
+ if last_hash != self.hash:
+ _LOGGER.debug("Change in model hash detected. Will regnerate engine file.")
+ return True
+ _LOGGER.debug("Existing engine file found.")
+ return False
+
+ def write_hash(self) -> None:
+ """Write the model hash to the engine directory."""
+ with open(self._last_hash_path, "w", encoding="ASCII") as hash_file:
+ hash_file.write(self.hash)
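A short sketch of how the engine cache above is meant to be used; it assumes model weights are mounted at `/model` and at least one GPU is visible, since creating a `Model` shells out to `nvidia-smi`:

```python
# Sketch only: skip TensorRT conversion when the cached engine still matches
# the mounted weights (the same logic main() applies via _should_convert).
from model_server.model import Model  # assumed import path

model = Model(model_type="llama")
if model.conversion_is_needed():   # no hash file yet, or the weights changed
    ...                            # run the TensorRT conversion here
    model.write_hash()             # record the weights hash next to the engine
else:
    print(f"Reusing cached engine in {model.engine_dir}")
```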
diff --git a/RetrievalAugmentedGeneration/llm-inference-server/model_server/server.py b/RetrievalAugmentedGeneration/llm-inference-server/model_server/server.py
new file mode 100644
index 00000000..a9077bf9
--- /dev/null
+++ b/RetrievalAugmentedGeneration/llm-inference-server/model_server/server.py
@@ -0,0 +1,153 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This module contains the code to statup triton inference servers."""
+import logging
+import os
+import subprocess
+import typing
+
+from jinja2 import Environment, FileSystemLoader
+
+from .model import Model, ModelFormats
+
+_ENSEMBLE_MODEL_DIR = "/opt/ensemble_models"
+_TRITON_BIN = "/opt/tritonserver/bin/tritonserver"
+_MPIRUN_BIN = "/usr/local/mpi/bin/mpirun"
+_LOGGER = logging.getLogger(__name__)
+
+
+class ModelServer:
+ """Abstraction of a multi-gpu triton inference server cluster."""
+
+ def __init__(self, model: Model, http: bool = False) -> None:
+ """Initialize the model server."""
+ self._model = model
+ self._http = http
+
+ @property
+ def _decoupled_mode(self) -> str:
+ """Indicate if the Triton models should be hosted in decoupled mode for streaming."""
+ if self._model.format == ModelFormats.NEMO:
+ return "false"
+ return "true" if not self._http else "false"
+
+ @property
+ def _allow_http(self) -> str:
+ """Indicate if Triton should allow http connections."""
+ return "true" if self._http else "false"
+
+ @property
+ def _allow_grpc(self) -> str:
+ """Inidicate if Triton should allow grpc connections."""
+ return "true" if not self._http else "false"
+
+ @property
+ def _tokenizer_model_dir(self) -> str:
+ """Inidicate where the tokenizer model can be found."""
+ if self._model.format == ModelFormats.NEMO:
+ return self._model.engine_dir
+ return self._model.model_dir
+
+ @property
+ def _gpt_model_type(self) -> str:
+ """Indicate the TRT LLM Backend mode."""
+ if self._model.format == ModelFormats.NEMO:
+ return "V1"
+ return "inflight_fused_batching"
+
+ @property
+ def model_repository(self) -> str:
+ """Return the triton model repository."""
+ return os.path.join(_ENSEMBLE_MODEL_DIR, self._model.family)
+
+ def _triton_server_cmd(self, rank: int) -> typing.List[str]:
+ """Generate the command to start a single triton server of given rank."""
+ return [
+ "-n",
+ "1",
+ _TRITON_BIN,
+ "--allow-http",
+ self._allow_http,
+ "--allow-grpc",
+ self._allow_grpc,
+ "--model-repository",
+ self.model_repository,
+ "--disable-auto-complete-config",
+ f"--backend-config=python,shm-region-prefix-name=prefix{rank}_",
+ ":",
+ ]
+
+ @property
+ def _cmd(self) -> typing.List[str]:
+ """Generate the full command."""
+ cmd = [_MPIRUN_BIN]
+ for rank in range(self._model.world_size):
+ cmd += self._triton_server_cmd(rank)
+ return cmd
+
+ @property
+ def _env(self) -> typing.Dict[str, str]:
+ """Return the environment variable for the triton inference server."""
+ env = dict(os.environ)
+ env["TRT_ENGINE_DIR"] = self._model.engine_dir
+ env["TOKENIZER_DIR"] = self._tokenizer_model_dir
+ if os.getuid() == 0:
+ _LOGGER.warning(
+ "Triton server will be running as root. It is recommended that you don't run this container as root."
+ )
+ env["OMPI_ALLOW_RUN_AS_ROOT"] = "1"
+ env["OMPI_ALLOW_RUN_AS_ROOT_CONFIRM"] = "1"
+ return env
+
+ def _render_model_templates(self) -> None:
+ """Render and Jinja templates in the model directory."""
+ env = Environment(
+ loader=FileSystemLoader(searchpath=self.model_repository),
+ autoescape=False,
+ ) # nosec; all the provided values are from code, not the user
+
+ template_path = os.path.join("tensorrt_llm", "config.pbtxt.j2")
+ output_path = os.path.join(
+ self.model_repository, "tensorrt_llm", "config.pbtxt"
+ )
+
+ template = env.get_template(template_path)
+
+ with open(output_path, "w", encoding="UTF-8") as out:
+ template_args = {
+ "engine_dir": self._model.engine_dir,
+ "decoupled_mode": self._decoupled_mode,
+ "gpt_model_type": self._gpt_model_type,
+ }
+ out.write(template.render(**template_args))
+
+ def run(self) -> int:
+ """Start the triton inference server."""
+ cmd = self._cmd
+ env = self._env
+
+ _LOGGER.debug("Rendering the ensemble models.")
+ self._render_model_templates()
+
+ _LOGGER.debug("Starting triton with the command: %s", " ".join(cmd))
+ _LOGGER.debug("Starting triton with the env vars: %s", repr(env))
+ with subprocess.Popen(cmd, env=env) as proc:
+ try:
+ retcode = proc.wait()
+ except KeyboardInterrupt:
+ proc.kill()
+ return 0
+ return retcode
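For reference, a rough sketch of the command `ModelServer.run()` launches for a two-GPU model: one `tritonserver` process per rank, chained together with `mpirun`'s `:` separator (the paths match the constants above; the rank count is illustrative):

```python
# Sketch only: mirrors ModelServer._cmd / _triton_server_cmd for world_size=2.
world_size = 2
cmd = ["/usr/local/mpi/bin/mpirun"]
for rank in range(world_size):
    cmd += [
        "-n", "1", "/opt/tritonserver/bin/tritonserver",
        "--allow-http", "false",
        "--allow-grpc", "true",
        "--model-repository", "/opt/ensemble_models/llama",
        "--disable-auto-complete-config",
        f"--backend-config=python,shm-region-prefix-name=prefix{rank}_",
        ":",
    ]
print(" ".join(cmd))
```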
diff --git a/RetrievalAugmentedGeneration/llm-inference-server/model_server_client/trt_llm.py b/RetrievalAugmentedGeneration/llm-inference-server/model_server_client/trt_llm.py
new file mode 100644
index 00000000..7291db4c
--- /dev/null
+++ b/RetrievalAugmentedGeneration/llm-inference-server/model_server_client/trt_llm.py
@@ -0,0 +1,544 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""A Langchain LLM component for connecting to Triton + TensorRT LLM backend."""
+# pylint: disable=too-many-lines
+import abc
+import json
+import queue
+import random
+import time
+from functools import partial
+from typing import Any, Callable, Dict, List, Optional, Type, Union
+
+import google.protobuf.json_format
+import numpy as np
+import tritonclient.grpc as grpcclient
+import tritonclient.http as httpclient
+from tritonclient.grpc.service_pb2 import ModelInferResponse
+from tritonclient.utils import np_to_triton_dtype
+
+try:
+ from langchain.callbacks.manager import CallbackManagerForLLMRun
+ from langchain.llms.base import LLM
+ from langchain.pydantic_v1 import Field, root_validator
+
+ USE_LANGCHAIN = True
+except ImportError:
+ USE_LANGCHAIN = False
+
+
+STOP_WORDS = [""]
+RANDOM_SEED = 0
+
+if USE_LANGCHAIN:
+ # pylint: disable-next=too-few-public-methods # Interface is defined by LangChain
+ class TensorRTLLM(LLM): # type: ignore # LLM class not typed in langchain
+ """A custom Langchain LLM class that integrates with TRTLLM triton models.
+
+ Arguments:
+ server_url: (str) The URL of the Triton inference server to use.
+ model_name: (str) The name of the Triton TRT model to use.
+ temperature: (float) Temperature to use for sampling
+ top_p: (float) The top-p value to use for sampling
+ top_k: (int) The top-k value to use for sampling
+ beam_width: (int) The number of beams to use for beam search
+ repetition_penalty: (float) The penalty to apply to repeated tokens
+ length_penalty: (float) The penalty to apply based on output length
+ tokens: (int) The maximum number of tokens to generate.
+ client: The client object used to communicate with the inference server
+ """
+
+ server_url: str = Field(None, alias="server_url")
+
+ # # all the optional arguments
+ model_name: str = "ensemble"
+ temperature: Optional[float] = 1.0
+ top_p: Optional[float] = 0
+ top_k: Optional[int] = 1
+ tokens: Optional[int] = 100
+ beam_width: Optional[int] = 1
+ repetition_penalty: Optional[float] = 1.0
+ length_penalty: Optional[float] = 1.0
+ client: Any
+ streaming: Optional[bool] = True
+
+ @root_validator() # type: ignore # typing not declared in langchain
+ @classmethod
+ def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+ """Validate that python package exists in environment."""
+ try:
+ if values.get("streaming", True):
+ values["client"] = GrpcTritonClient(values["server_url"])
+ else:
+ values["client"] = HttpTritonClient(values["server_url"])
+
+ except ImportError as err:
+ raise ImportError(
+ "Could not import triton client python package. "
+ "Please install it with `pip install tritonclient[all]`."
+ ) from err
+ return values
+
+ @property
+ def _get_model_default_parameters(self) -> Dict[str, Any]:
+ return {
+ "tokens": self.tokens,
+ "top_k": self.top_k,
+ "top_p": self.top_p,
+ "temperature": self.temperature,
+ "repetition_penalty": self.repetition_penalty,
+ "length_penalty": self.length_penalty,
+ "beam_width": self.beam_width,
+ }
+
+ @property
+ def _invocation_params(self, **kwargs: Any) -> Dict[str, Any]:
+ params = {**self._get_model_default_parameters, **kwargs}
+ return params
+
+ @property
+ def _identifying_params(self) -> Dict[str, Any]:
+ """Get all the identifying parameters."""
+ return {
+ "server_url": self.server_url,
+ "model_name": self.model_name,
+ }
+
+ @property
+ def _llm_type(self) -> str:
+ return "triton_tensorrt"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None, # pylint: disable=unused-argument
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ """
+ Execute an inference request.
+
+ Args:
+ prompt: The prompt to pass into the model.
+ stop: A list of strings to stop generation when encountered
+
+ Returns:
+ The string generated by the model
+ """
+ text_callback = None
+ if run_manager:
+ text_callback = partial(
+ run_manager.on_llm_new_token, verbose=self.verbose
+ )
+
+ invocation_params = self._get_model_default_parameters
+ invocation_params.update(kwargs)
+ invocation_params["prompt"] = [[prompt]]
+ model_params = self._identifying_params
+ model_params.update(kwargs)
+ request_id = str(random.randint(1, 9999999)) # nosec
+
+ self.client.load_model(model_params["model_name"])
+ if isinstance(self.client, GrpcTritonClient):
+ return self._streaming_request(
+ model_params, request_id, invocation_params, text_callback
+ )
+ return self._request(model_params, invocation_params, text_callback)
+
+ def _streaming_request(
+ self,
+ model_params: Dict[str, Any],
+ request_id: str,
+ invocation_params: Dict[str, Any],
+ text_callback: Optional[Callable[[str], None]],
+ ) -> str:
+ """Request a streaming inference session."""
+ result_queue = self.client.request_streaming(
+ model_params["model_name"], request_id, **invocation_params
+ )
+
+ response = ""
+ for token in result_queue:
+ if text_callback:
+ text_callback(token)
+ response = response + token
+ return response
+
+ def _request(
+ self,
+ model_params: Dict[str, Any],
+ invocation_params: Dict[str, Any],
+ text_callback: Optional[Callable[[str], None]],
+ ) -> str:
+ """Request a streaming inference session."""
+ token: str = self.client.request(
+ model_params["model_name"], **invocation_params
+ )
+ if text_callback:
+ text_callback(token)
+ return token
+
+
+class StreamingResponseGenerator(queue.Queue[Optional[str]]):
+ """A Generator that provides the inference results from an LLM."""
+
+ def __init__(
+ self, client: "GrpcTritonClient", request_id: str, force_batch: bool
+ ) -> None:
+ """Instantiate the generator class."""
+ super().__init__()
+ self._client = client
+ self.request_id = request_id
+ self._batch = force_batch
+
+ def __iter__(self) -> "StreamingResponseGenerator":
+ """Return self as a generator."""
+ return self
+
+ def __next__(self) -> str:
+ """Return the next retrieved token."""
+ val = self.get()
+ if val is None or val in STOP_WORDS:
+ self._stop_stream()
+ raise StopIteration()
+ return val
+
+ def _stop_stream(self) -> None:
+ """Drain and shutdown the Triton stream."""
+ self._client.stop_stream(
+ "tensorrt_llm", self.request_id, signal=not self._batch
+ )
+
+
+class _BaseTritonClient(abc.ABC):
+ """An abstraction of the connection to a triton inference server."""
+
+ def __init__(self, server_url: str) -> None:
+ """Initialize the client."""
+ self._server_url = server_url
+ self._client = self._inference_server_client(server_url)
+
+ @property
+ @abc.abstractmethod
+ def _inference_server_client(
+ self,
+ ) -> Union[
+ Type[grpcclient.InferenceServerClient], Type[httpclient.InferenceServerClient]
+ ]:
+ """Return the prefered InferenceServerClient class."""
+
+ @property
+ @abc.abstractmethod
+ def _infer_input(
+ self,
+ ) -> Union[Type[grpcclient.InferInput], Type[httpclient.InferInput]]:
+ """Return the preferred InferInput."""
+
+ @property
+ @abc.abstractmethod
+ def _infer_output(
+ self,
+ ) -> Union[
+ Type[grpcclient.InferRequestedOutput], Type[httpclient.InferRequestedOutput]
+ ]:
+ """Return the preferred InferRequestedOutput."""
+
+ def load_model(self, model_name: str, timeout: int = 1000) -> None:
+ """Load a model into the server."""
+ if self._client.is_model_ready(model_name):
+ return
+
+ self._client.load_model(model_name)
+ t0 = time.perf_counter()
+ t1 = t0
+ while not self._client.is_model_ready(model_name) and t1 - t0 < timeout:
+ t1 = time.perf_counter()
+
+ if not self._client.is_model_ready(model_name):
+ raise RuntimeError(f"Failed to load {model_name} on Triton in {timeout}s")
+
+ def get_model_list(self) -> List[str]:
+ """Get a list of models loaded in the triton server."""
+ res = self._client.get_model_repository_index(as_json=True)
+ return [model["name"] for model in res["models"]]
+
+ def get_model_concurrency(self, model_name: str, timeout: int = 1000) -> int:
+ """Get the modle concurrency."""
+ self.load_model(model_name, timeout)
+ instances = self._client.get_model_config(model_name, as_json=True)["config"][
+ "instance_group"
+ ]
+ return sum(instance["count"] * len(instance["gpus"]) for instance in instances)
+
+ def _generate_stop_signals(
+ self,
+ ) -> List[Union[grpcclient.InferInput, httpclient.InferInput]]:
+ """Generate the signal to stop the stream."""
+ inputs = [
+ self._infer_input("input_ids", [1, 1], "INT32"),
+ self._infer_input("input_lengths", [1, 1], "INT32"),
+ self._infer_input("request_output_len", [1, 1], "UINT32"),
+ self._infer_input("stop", [1, 1], "BOOL"),
+ ]
+ inputs[0].set_data_from_numpy(np.empty([1, 1], dtype=np.int32))
+ inputs[1].set_data_from_numpy(np.zeros([1, 1], dtype=np.int32))
+ inputs[2].set_data_from_numpy(np.array([[0]], dtype=np.uint32))
+ inputs[3].set_data_from_numpy(np.array([[True]], dtype="bool"))
+ return inputs
+
+ def _generate_outputs(
+ self,
+ ) -> List[Union[grpcclient.InferRequestedOutput, httpclient.InferRequestedOutput]]:
+ """Generate the expected output structure."""
+ return [self._infer_output("text_output")]
+
+ def _prepare_tensor(
+ self, name: str, input_data: Any
+ ) -> Union[grpcclient.InferInput, httpclient.InferInput]:
+ """Prepare an input data structure."""
+ t = self._infer_input(
+ name, input_data.shape, np_to_triton_dtype(input_data.dtype)
+ )
+ t.set_data_from_numpy(input_data)
+ return t
+
+ def _generate_inputs( # pylint: disable=too-many-arguments,too-many-locals
+ self,
+ prompt: str,
+ tokens: int = 300,
+ temperature: float = 1.0,
+ top_k: float = 1,
+ top_p: float = 0,
+ beam_width: int = 1,
+ repetition_penalty: float = 1,
+ length_penalty: float = 1.0,
+ stream: bool = True,
+ ) -> List[Union[grpcclient.InferInput, httpclient.InferInput]]:
+ """Create the input for the triton inference server."""
+ query = np.array(prompt).astype(object)
+ request_output_len = np.array([tokens]).astype(np.uint32).reshape((1, -1))
+ runtime_top_k = np.array([top_k]).astype(np.uint32).reshape((1, -1))
+ runtime_top_p = np.array([top_p]).astype(np.float32).reshape((1, -1))
+ temperature_array = np.array([temperature]).astype(np.float32).reshape((1, -1))
+ len_penalty = np.array([length_penalty]).astype(np.float32).reshape((1, -1))
+ repetition_penalty_array = (
+ np.array([repetition_penalty]).astype(np.float32).reshape((1, -1))
+ )
+ random_seed = np.array([RANDOM_SEED]).astype(np.uint64).reshape((1, -1))
+ beam_width_array = np.array([beam_width]).astype(np.uint32).reshape((1, -1))
+ streaming_data = np.array([[stream]], dtype=bool)
+
+ inputs = [
+ self._prepare_tensor("text_input", query),
+ self._prepare_tensor("max_tokens", request_output_len),
+ self._prepare_tensor("top_k", runtime_top_k),
+ self._prepare_tensor("top_p", runtime_top_p),
+ self._prepare_tensor("temperature", temperature_array),
+ self._prepare_tensor("length_penalty", len_penalty),
+ self._prepare_tensor("repetition_penalty", repetition_penalty_array),
+ self._prepare_tensor("random_seed", random_seed),
+ self._prepare_tensor("beam_width", beam_width_array),
+ self._prepare_tensor("stream", streaming_data),
+ ]
+ return inputs
+
+ def _trim_batch_response(self, result_str: str) -> str:
+ """Trim the resulting response from a batch request by removing provided prompt and extra generated text."""
+ # extract the generated part of the prompt
+ split = result_str.split("[/INST]", 1)
+ generated = split[-1]
+ end_token = generated.find("")
+ if end_token == -1:
+ return generated
+ generated = generated[:end_token].strip()
+ return generated
+
+
+class GrpcTritonClient(_BaseTritonClient):
+ """GRPC connection to a triton inference server."""
+
+ @property
+ def _inference_server_client(
+ self,
+ ) -> Type[grpcclient.InferenceServerClient]:
+ """Return the prefered InferenceServerClient class."""
+ return grpcclient.InferenceServerClient # type: ignore
+
+ @property
+ def _infer_input(self) -> Type[grpcclient.InferInput]:
+ """Return the preferred InferInput."""
+ return grpcclient.InferInput # type: ignore
+
+ @property
+ def _infer_output(
+ self,
+ ) -> Type[grpcclient.InferRequestedOutput]:
+ """Return the preferred InferRequestedOutput."""
+ return grpcclient.InferRequestedOutput # type: ignore
+
+ def _send_stop_signals(self, model_name: str, request_id: str) -> None:
+ """Send the stop signal to the Triton Inference server."""
+ stop_inputs = self._generate_stop_signals()
+ self._client.async_stream_infer(
+ model_name,
+ stop_inputs,
+ request_id=request_id,
+ parameters={"Streaming": True},
+ )
+
+ @staticmethod
+ def _process_result(result: Dict[str, str]) -> str:
+ """Post-process the result from the server."""
+ message = ModelInferResponse()
+ generated_text: str = ""
+ google.protobuf.json_format.Parse(json.dumps(result), message)
+ infer_result = grpcclient.InferResult(message)
+ np_res = infer_result.as_numpy("text_output")
+
+ generated_text = ""
+ if np_res is not None:
+ generated_text = "".join([token.decode() for token in np_res])
+
+ return generated_text
+
+ def _stream_callback(
+ self,
+ result_queue: queue.Queue[Union[Optional[Dict[str, str]], str]],
+ force_batch: bool,
+ result: Any,
+ error: str,
+ ) -> None:
+ """Add streamed result to queue."""
+ if error:
+ result_queue.put(error)
+ else:
+ response_raw = result.get_response(as_json=True)
+ if "outputs" in response_raw:
+ # the very last response might have no output, just the final flag
+ response = self._process_result(response_raw)
+ if force_batch:
+ response = self._trim_batch_response(response)
+
+ if response in STOP_WORDS:
+ result_queue.put(None)
+ else:
+ result_queue.put(response)
+
+ if response_raw["parameters"]["triton_final_response"]["bool_param"]:
+ # end of the generation
+ result_queue.put(None)
+
+ # pylint: disable-next=too-many-arguments
+ def _send_prompt_streaming(
+ self,
+ model_name: str,
+ request_inputs: Any,
+ request_outputs: Optional[Any],
+ request_id: str,
+ result_queue: StreamingResponseGenerator,
+ force_batch: bool = False,
+ ) -> None:
+ """Send the prompt and start streaming the result."""
+ self._client.start_stream(
+ callback=partial(self._stream_callback, result_queue, force_batch)
+ )
+ self._client.async_stream_infer(
+ model_name=model_name,
+ inputs=request_inputs,
+ outputs=request_outputs,
+ request_id=request_id,
+ )
+
+ def request_streaming(
+ self,
+ model_name: str,
+ request_id: Optional[str] = None,
+ force_batch: bool = False,
+ **params: Any,
+ ) -> StreamingResponseGenerator:
+ """Request a streaming connection."""
+ if not self._client.is_model_ready(model_name):
+ raise RuntimeError("Cannot request streaming, model is not loaded")
+
+ if not request_id:
+ request_id = str(random.randint(1, 9999999)) # nosec
+
+ result_queue = StreamingResponseGenerator(self, request_id, force_batch)
+ inputs = self._generate_inputs(stream=not force_batch, **params)
+ outputs = self._generate_outputs()
+ self._send_prompt_streaming(
+ model_name,
+ inputs,
+ outputs,
+ request_id,
+ result_queue,
+ force_batch,
+ )
+ return result_queue
+
+ def stop_stream(
+ self, model_name: str, request_id: str, signal: bool = True
+ ) -> None:
+ """Close the streaming connection."""
+ if signal:
+ self._send_stop_signals(model_name, request_id)
+ self._client.stop_stream()
+
+
+class HttpTritonClient(_BaseTritonClient):
+ """HTTP connection to a triton inference server."""
+
+ @property
+ def _inference_server_client(
+ self,
+ ) -> Type[httpclient.InferenceServerClient]:
+ """Return the prefered InferenceServerClient class."""
+ return httpclient.InferenceServerClient # type: ignore
+
+ @property
+ def _infer_input(self) -> Type[httpclient.InferInput]:
+ """Return the preferred InferInput."""
+ return httpclient.InferInput # type: ignore
+
+ @property
+ def _infer_output(
+ self,
+ ) -> Type[httpclient.InferRequestedOutput]:
+ """Return the preferred InferRequestedOutput."""
+ return httpclient.InferRequestedOutput # type: ignore
+
+ def request(
+ self,
+ model_name: str,
+ **params: Any,
+ ) -> str:
+ """Request inferencing from the triton server."""
+ if not self._client.is_model_ready(model_name):
+ raise RuntimeError("Cannot request streaming, model is not loaded")
+
+ # create model inputs and outputs
+ inputs = self._generate_inputs(stream=False, **params)
+ outputs = self._generate_outputs()
+
+ # call the model for inference
+ result = self._client.infer(model_name, inputs=inputs, outputs=outputs)
+ result_str = "".join(
+ [val.decode("utf-8") for val in result.as_numpy("text_output").tolist()]
+ )
+
+ # extract the generated part of the prompt
+ return self._trim_batch_response(result_str)
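A hedged usage sketch for the LangChain wrapper above; it assumes `langchain` and `tritonclient` are installed, that the module is importable as `model_server_client.trt_llm`, and that a Triton instance serving the `ensemble` model is reachable at the given gRPC address:

```python
# Sketch only: stream a completion from the Triton-hosted TensorRT-LLM model.
from model_server_client.trt_llm import TensorRTLLM

llm = TensorRTLLM(
    server_url="localhost:8001",  # hypothetical gRPC endpoint; streaming by default
    model_name="ensemble",
    tokens=256,
    temperature=0.7,
)
print(llm("What is retrieval augmented generation?"))
```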
diff --git a/RetrievalAugmentedGeneration/llm-inference-server/requirements.txt b/RetrievalAugmentedGeneration/llm-inference-server/requirements.txt
new file mode 100644
index 00000000..3f6a7734
--- /dev/null
+++ b/RetrievalAugmentedGeneration/llm-inference-server/requirements.txt
@@ -0,0 +1,7 @@
+jinja2
+langchain
+numpy
+protobuf
+requests
+tritonclient[all]
+pyyaml
diff --git a/SECURITY.md b/SECURITY.md
new file mode 100644
index 00000000..35180962
--- /dev/null
+++ b/SECURITY.md
@@ -0,0 +1,24 @@
+## Security
+
+NVIDIA is dedicated to the security and trust of our software products and services, including all source code repositories managed through our organization.
+
+If you need to report a security issue, please use the appropriate contact points outlined below. **Please do not report security vulnerabilities through GitHub.**
+
+## Reporting Potential Security Vulnerability in an NVIDIA Product
+
+To report a potential security vulnerability in any NVIDIA product:
+- Web: [Security Vulnerability Submission Form](https://www.nvidia.com/object/submit-security-vulnerability.html)
+- E-Mail: psirt@nvidia.com
+ - We encourage you to use the following PGP key for secure email communication: [NVIDIA public PGP Key for communication](https://www.nvidia.com/en-us/security/pgp-key)
+ - Please include the following information:
+ - Product/Driver name and version/branch that contains the vulnerability
+ - Type of vulnerability (code execution, denial of service, buffer overflow, etc.)
+ - Instructions to reproduce the vulnerability
+ - Proof-of-concept or exploit code
+ - Potential impact of the vulnerability, including how an attacker could exploit the vulnerability
+
+While NVIDIA currently does not have a bug bounty program, we do offer acknowledgement when an externally reported security issue is addressed under our coordinated vulnerability disclosure policy. Please visit our [Product Security Incident Response Team (PSIRT)](https://www.nvidia.com/en-us/security/psirt-policies/) policies page for more information.
+
+## NVIDIA Product Security
+
+For all security-related concerns, please visit NVIDIA's Product Security portal at https://www.nvidia.com/en-us/security
\ No newline at end of file