From 8222b4b771f980a526e5f0f2bc29d4ff1abd6e2a Mon Sep 17 00:00:00 2001 From: Cheng-Hsin Weng Date: Mon, 27 Apr 2026 13:30:33 +0800 Subject: [PATCH] Qualcomm AI Engine Direct - Refactor llama runner for dynamic IO dtypes - Summary To enable GPU backend support in the Llama runner, refactoring is required because the dtypes of kv_cache, attention_mask, and logits are currently hardcoded, preventing floating point models from running. This PR fixes the issue by removing the hardcoded dtypes in the runner. --- backends/qualcomm/_passes/build_quant_io.py | 48 +-- backends/qualcomm/tests/test_qnn_delegate.py | 15 +- backends/qualcomm/tests/utils.py | 1 + .../llama/decoder_runtime_evaluator.py | 2 +- examples/qualcomm/oss_scripts/llama/llama.py | 64 ++- .../oss_scripts/llama/qnn_llama_runner.cpp | 25 +- .../llama/qnn_multimodal_runner.cpp | 38 +- .../oss_scripts/llama/runner/decoder_runner.h | 28 +- .../oss_scripts/llama/runner/kv_manager.cpp | 366 +++++++++++------- .../oss_scripts/llama/runner/kv_manager.h | 43 +- .../llama/runner/lhd_token_generator.cpp | 29 +- .../llama/runner/lhd_token_generator.h | 18 +- .../multimodal_lhd_token_generator.cpp | 26 +- .../multimodal_lhd_token_generator.h | 18 +- .../multimodal_prompt_processor.cpp | 53 ++- .../multimodal_prompt_processor.h | 51 ++- .../multimodal_runner/multimodal_runner.cpp | 73 ++-- .../multimodal_runner/multimodal_runner.h | 12 +- .../multimodal_token_generator.cpp | 50 +-- .../multimodal_token_generator.h | 43 +- .../llama/runner/prompt_processor.cpp | 84 ++-- .../llama/runner/prompt_processor.h | 30 +- .../oss_scripts/llama/runner/runner.cpp | 71 ++-- .../oss_scripts/llama/runner/runner.h | 13 +- .../llama/runner/token_generator.cpp | 80 ++-- .../llama/runner/token_generator.h | 30 +- .../qualcomm/oss_scripts/llama/runner/utils.h | 41 ++ .../llama/wrappers/attention_sink_wrappers.py | 2 + .../llama/wrappers/llm_wrappers.py | 64 +-- exir/passes/spec_prop_pass.py | 10 +- extension/android/jni/jni_layer_llama.cpp | 43 
+- extension/llm/custom_ops/model_sharding.py | 24 +- extension/llm/custom_ops/op_fallback.py | 29 ++ 33 files changed, 811 insertions(+), 713 deletions(-) create mode 100644 extension/llm/custom_ops/op_fallback.py diff --git a/backends/qualcomm/_passes/build_quant_io.py b/backends/qualcomm/_passes/build_quant_io.py index d43842e84a5..057dcc0f864 100644 --- a/backends/qualcomm/_passes/build_quant_io.py +++ b/backends/qualcomm/_passes/build_quant_io.py @@ -5,11 +5,10 @@ # LICENSE file in the root directory of this source tree. import torch from executorch.backends.qualcomm.utils.constants import QCOM_QUANTIZED_IO -from executorch.exir.delegate import executorch_call_delegate -from executorch.exir.pass_base import ExportPass, ProxyValue +from executorch.exir.delegate import executorch_call_delegate +from executorch.exir.pass_base import ExportPass, PassResult from executorch.exir.tensor import TensorSpec -from torch.utils import _pytree as pytree class BuildQuantIo(ExportPass): @@ -28,22 +27,27 @@ def _make_spec(self, x): else: return None - def placeholder(self, name: str, arg, meta): - if quantized_dtype := meta.data.get(QCOM_QUANTIZED_IO, None): - arg = arg.to(dtype=quantized_dtype) - meta["spec"] = self._make_spec(arg) - return super().placeholder(name, arg, meta) - - def call_getitem(self, value, key: int, meta): - meta["spec"] = value.node.meta["spec"][key] - return super().call_getitem(value, key, meta) - - def call_delegate(self, lowered_module, args, kwargs, meta): - args_data, _ = pytree.tree_map_only( - ProxyValue, lambda x: x.data, (args, kwargs) - ) - meta["spec"] = pytree.tree_map( - self._make_spec, - executorch_call_delegate(lowered_module, *args_data), - ) - return super().call_delegate(lowered_module, args, kwargs, meta) + def _build(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule: + # Forcedly update delegate node's meta['spec'] to get correct output + # tensor size in runtime + call_delegates = [ + node + for node in 
graph_module.graph.nodes + if node.op == "call_function" and node.target == executorch_call_delegate + ] + for n in graph_module.graph.nodes: + if QCOM_QUANTIZED_IO in n.meta: + n.meta["val"] = n.meta["val"].to(dtype=n.meta[QCOM_QUANTIZED_IO]) + n.meta["spec"] = self._make_spec(n.meta["val"]) + + for call_delegate in call_delegates: + spec = [] + for user in list(call_delegate.users): + spec.append(self._make_spec(user.meta["val"])) + call_delegate.meta["spec"] = tuple(spec) + + def call(self, graph_module: torch.fx.GraphModule): + self._build(graph_module) + graph_module.graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index 688dddf5c2a..e6ae833ccbd 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -7475,6 +7475,8 @@ def test_llama_stories_110m(self): "--max_context_len", "128", ] + if self.use_fp16: + cmds.append("--use_fp16") self.add_default_cmds(cmds) golden_start_with = "Once upon a time," @@ -7495,7 +7497,10 @@ def test_llama_stories_110m(self): # x86 does not allow weight sharing, so we don't check pte size if not self.enable_x86_64: pte_size = msg["pte_size"] - self.assertLessEqual(pte_size, 135_000_000) # 135MB + if self.use_fp16: + self.assertLessEqual(pte_size, 275_000_000) # 275MB + else: + self.assertLessEqual(pte_size, 135_000_000) # 135MB if not self.compile_only and not self.enable_x86_64: self.assertGreaterEqual(msg["inference_speed"], 220) # Lanai @@ -9830,6 +9835,13 @@ def setup_environment(): choices=["wikitext_ppl", "hellaswag_acc_norm", "sqnr"], type=str, ) + parser.add_argument( + "-F", + "--use_fp16", + help="If specified, will run in fp16 precision and discard ptq setting", + action="store_true", + default=False, + ) args, ns_args = parser.parse_known_args(namespace=unittest) TestQNN.host = args.host @@ -9858,6 +9870,7 @@ def 
setup_environment(): TestQNN.backend = args.backend TestQNN.static_llm_eval_method = args.static_llm_eval_method TestQNN.direct_build_folder = args.direct_build_folder + TestQNN.use_fp16 = args.use_fp16 return sys.argv[:1] + ns_args diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index c25e1bf789d..6a2532a9e74 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -217,6 +217,7 @@ class TestQNN(unittest.TestCase): inference_speed_output_path = "outputs/inference_speed.txt" static_llm_eval_method = "" direct_build_folder: str = "" + use_fp16 = False @classmethod def setUpClass(cls): diff --git a/examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py b/examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py index 7bebf513658..a75e67933e5 100644 --- a/examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py +++ b/examples/qualcomm/oss_scripts/llama/decoder_runtime_evaluator.py @@ -133,7 +133,7 @@ def _init_runner_base_cmd(self): base_cmd = " ".join( [ f"export LD_LIBRARY_PATH={self.qnn_sdk}/lib/x86_64-linux-clang/:{args.build_folder}/lib &&", - f"./{args.build_folder}/examples/qualcomm/oss_scripts/llama/{self.runner}", + f"{args.build_folder}/examples/qualcomm/oss_scripts/llama/{self.runner}", f"--decoder_model_version {DECODER_MODEL_VERSION[args.decoder_model]}", f"--tokenizer_path {self.runtime_tokenizer_path}", f"--output_path {self.device_output_response_path}", diff --git a/examples/qualcomm/oss_scripts/llama/llama.py b/examples/qualcomm/oss_scripts/llama/llama.py index a8e28f96b71..2b5befb1711 100755 --- a/examples/qualcomm/oss_scripts/llama/llama.py +++ b/examples/qualcomm/oss_scripts/llama/llama.py @@ -21,6 +21,7 @@ ) from executorch.backends.qualcomm.utils.utils import ( + generate_gpu_compiler_spec, generate_htp_compiler_spec, generate_qnn_executorch_compiler_spec, get_soc_to_chipset_map, @@ -119,9 +120,15 @@ def compile( # because the encoder is quite sensitive and 
quantization can make it harder for the model to distinguish # between images within the same conversation. to_skip = len(args.image_path) > 1 - backend_options = generate_htp_compiler_spec( - use_fp16=to_skip, - ) + if args.backend == "htp": + backend_options = generate_htp_compiler_spec( + use_fp16=to_skip, + ) + elif args.backend == "gpu": + backend_options = generate_gpu_compiler_spec() + else: + raise ValueError(f"Unsupported backend {args.backend}") + encoder_compile_specs = generate_qnn_executorch_compiler_spec( soc_model=get_soc_to_chipset_map()[args.soc_model], backend_options=backend_options, @@ -131,27 +138,40 @@ def compile( skip_quantize[modality] = to_skip compile_specs[modality] = encoder_compile_specs elif is_multimodal and modality == TOK_EMBEDDING: - backend_options = generate_htp_compiler_spec( - use_fp16=False, - # x86 emulator does not support weight sharing - use_weight_sharing=not args.enable_x86_64, - ) + if args.backend == "htp": + backend_options = generate_htp_compiler_spec( + use_fp16=False, + # x86 emulator does not support weight sharing + use_weight_sharing=not args.enable_x86_64, + ) + elif args.backend == "gpu": + backend_options = generate_gpu_compiler_spec() + else: + raise ValueError(f"Unsupported backend {args.backend}") + compile_specs[modality] = [ generate_qnn_executorch_compiler_spec( soc_model=get_soc_to_chipset_map()[args.soc_model], backend_options=backend_options, # x86 emulator does not support shared buffer shared_buffer=not args.enable_x86_64, + online_prepare=args.online_prepare, ) ] * len(TOK_EMBEDDING_GRAPH_NAMES) elif modality == TEXT_DECODER: # compile spec for text decoder - backend_options = generate_htp_compiler_spec( - use_fp16=False, - use_multi_contexts=decoder_model_config.num_sharding > 1, - # x86 emulator does not support weight sharing - use_weight_sharing=not args.enable_x86_64, - ) + if args.backend == "htp": + backend_options = generate_htp_compiler_spec( + use_fp16=args.use_fp16, + 
use_multi_contexts=decoder_model_config.num_sharding > 1, + # x86 emulator does not support weight sharing + use_weight_sharing=not args.enable_x86_64, + ) + elif args.backend == "gpu": + backend_options = generate_gpu_compiler_spec() + else: + raise ValueError(f"Unsupported backend {args.backend}") + skip_quantize[modality] = args.use_fp16 compile_specs[modality] = [ generate_qnn_executorch_compiler_spec( soc_model=get_soc_to_chipset_map()[args.soc_model], @@ -159,6 +179,7 @@ def compile( # x86 emulator does not support shared buffer shared_buffer=not args.enable_x86_64, use_mha2sha=True, + online_prepare=args.online_prepare, ) ] * len(DECODER_GRAPH_NAMES) @@ -529,6 +550,14 @@ def _build_parser(): help="Number of examples in few-shot context", ) + parser.add_argument( + "-F", + "--use_fp16", + help="If specified, will run in fp16 precision and discard ptq setting", + action="store_true", + default=False, + ) + parser.add_argument("-v", "--verbose", action="store_true") parser.add_argument( @@ -592,6 +621,12 @@ def export_llama(args) -> None: pte_filename = "lookahead_llama_qnn" else: raise RuntimeError(f"Unknown model_mode: {args.model_mode}.") + + if args.model_mode == "hybrid" and args.online_prepare: + raise RuntimeError( + "Currently hybrid mode is not compatible with online_prepare." 
+ ) + if args.decoder_model == "stories260k": pte_filename = f"{args.decoder_model}_" + pte_filename pte_filenames = { @@ -740,6 +775,7 @@ def export_llama(args) -> None: def main(): parser = _build_parser() args = parser.parse_args() + args.build_folder = os.path.realpath(args.build_folder) try: export_llama(args) except Exception as e: diff --git a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp index e9a3dcfe4e2..52a33c6984c 100644 --- a/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp @@ -210,7 +210,6 @@ std::string get_formatted_prompt( return formatted_prompt; } -template void start_runner( std::unique_ptr module, std::vector& prompts, @@ -219,7 +218,7 @@ void start_runner( gflags::GetCommandLineFlagInfoOrDie("tokenized_prompt").is_default ? false : true; // create llama runner - example::Runner runner( + example::Runner runner( std::move(module), FLAGS_decoder_model_version.c_str(), FLAGS_model_path.c_str(), @@ -296,26 +295,8 @@ int main(int argc, char** argv) { FLAGS_attention_sink_rope_path.c_str(), executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors); } - // Using 8bit as default since this meta is introduced with 16bit kv io - // support and older models only have 8bit kv io. 
- example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8; - if (module->method_names()->count("get_kv_io_bit_width") > 0) { - kv_bitwidth = static_cast( - module->get("get_kv_io_bit_width").get().toScalar().to()); - } - - if (kv_bitwidth == example::KvBitWidth::kWidth8) { - start_runner( - std::move(module), prompts, std::move(attention_sink_rope_module)); - } else if (kv_bitwidth == example::KvBitWidth::kWidth16) { - start_runner( - std::move(module), prompts, std::move(attention_sink_rope_module)); - } else { - ET_CHECK_MSG( - false, - "Unsupported kv bitwidth: %ld", - static_cast(kv_bitwidth)); - } + start_runner( + std::move(module), prompts, std::move(attention_sink_rope_module)); return 0; } diff --git a/examples/qualcomm/oss_scripts/llama/qnn_multimodal_runner.cpp b/examples/qualcomm/oss_scripts/llama/qnn_multimodal_runner.cpp index 29b6b9d7ddc..c9c2bd19940 100644 --- a/examples/qualcomm/oss_scripts/llama/qnn_multimodal_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/qnn_multimodal_runner.cpp @@ -137,7 +137,6 @@ std::vector CollectPrompts(int argc, char** argv) { return prompts; } -template void start_multimodal_runner( std::unique_ptr encoder, std::unique_ptr tok_embedding, @@ -150,7 +149,7 @@ void start_multimodal_runner( : true; // Create multimodal runner - example::QNNMultimodalRunner runner( + example::QNNMultimodalRunner runner( std::move(encoder), std::move(tok_embedding), std::move(text_decoder), @@ -289,35 +288,12 @@ int main(int argc, char** argv) { FLAGS_decoder_path.c_str(), executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors); - // Using 8bit as default since this meta is introduced with 16bit kv io - // support and older models only have 8bit kv io. 
- example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8; - if (text_decoder->method_names()->count("get_kv_io_bit_width") > 0) { - kv_bitwidth = static_cast( - text_decoder->get("get_kv_io_bit_width") - .get() - .toScalar() - .to()); - } - // Start runner with appropriate KV bitwidth - if (kv_bitwidth == example::KvBitWidth::kWidth8) { - start_multimodal_runner( - std::move(encoder), - std::move(tok_embedding), - std::move(text_decoder), - prompts); - } else if (kv_bitwidth == example::KvBitWidth::kWidth16) { - start_multimodal_runner( - std::move(encoder), - std::move(tok_embedding), - std::move(text_decoder), - prompts); - } else { - ET_CHECK_MSG( - false, - "Unsupported kv bitwidth: %ld", - static_cast(kv_bitwidth)); - } + // Start runner + start_multimodal_runner( + std::move(encoder), + std::move(tok_embedding), + std::move(text_decoder), + prompts); return 0; } diff --git a/examples/qualcomm/oss_scripts/llama/runner/decoder_runner.h b/examples/qualcomm/oss_scripts/llama/runner/decoder_runner.h index 888e9acd421..b714f737de3 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/decoder_runner.h +++ b/examples/qualcomm/oss_scripts/llama/runner/decoder_runner.h @@ -8,6 +8,7 @@ #pragma once +#include #include #include #include @@ -56,19 +57,36 @@ class DecoderRunner { inline int32_t logits_to_token( const executorch::aten::Tensor& logits_tensor, int64_t pos) { - auto* logits = logits_tensor.mutable_data_ptr(); + std::byte* logits = logits_tensor.mutable_data_ptr(); auto num_tokens = logits_tensor.size(1); auto vocab_size = logits_tensor.size(2); static std::vector logits_f(vocab_size); - auto* logits_last = logits; + std::byte* logits_last = logits; // offset to the meaningful logit we want for prefill model. 
+ executorch::aten::ScalarType logits_dtype = logits_tensor.scalar_type(); + size_t logits_nbytes = getDtypeSize(logits_dtype); if (num_tokens > 1) { - logits_last += pos * vocab_size; + logits_last += pos * vocab_size * logits_nbytes; } - // Discard dequantization (converting uint16_t to float) because the + // Discard dequantization (converting std::byte to float) because the // relative order of elements remains the same without conversion for (int i = 0; i < vocab_size; i++) { - logits_f[i] = logits_last[i]; + switch (logits_dtype) { + case executorch::aten::ScalarType::UInt16: + logits_f[i] = reinterpret_cast(logits_last)[i]; + break; + case executorch::aten::ScalarType::Byte: + logits_f[i] = reinterpret_cast(logits_last)[i]; + break; + case executorch::aten::ScalarType::Float: + logits_f[i] = reinterpret_cast(logits_last)[i]; + break; + default: + ET_CHECK_MSG( + false, + "The scalar_type %s of logits is not supported", + executorch::runtime::toString(logits_dtype)); + } } return sampler_->sample(logits_f.data()); } diff --git a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp index e5c12068bab..7288ca5fbd1 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.cpp @@ -7,24 +7,105 @@ */ #include +#include #include + +using executorch::runtime::MethodMeta; +using executorch::runtime::Result; +using executorch::runtime::TensorInfo; namespace example { -template -KVManager::KVManager(Metadata metadata) : metadata_(metadata) { + +namespace { +void fill_mask( + executorch::aten::ScalarType scalar_type, + std::byte* buf, + size_t size, + bool use_pos_value) { + if (use_pos_value) { + switch (scalar_type) { + case executorch::aten::ScalarType::UInt16: + std::fill_n(reinterpret_cast(buf), size, 65535u); + break; + case executorch::aten::ScalarType::Byte: + std::fill_n(reinterpret_cast(buf), size, 255u); + break; + case 
executorch::aten::ScalarType::Float: + std::fill_n(reinterpret_cast(buf), size, 0.0); + break; + default: + ET_CHECK_MSG( + false, + "Unsupported scalar type %s", + executorch::runtime::toString(scalar_type)); + break; + } + } else { + switch (scalar_type) { + case executorch::aten::ScalarType::UInt16: + std::fill_n(reinterpret_cast(buf), size, 0u); + break; + case executorch::aten::ScalarType::Byte: + std::fill_n(reinterpret_cast(buf), size, 0u); + break; + // -65535 acts as the additive "very negative" attention-mask value; + // chosen as a large finite negative so masked positions effectively + // zero out after softmax without relying on -inf. + case executorch::aten::ScalarType::Float: + std::fill_n(reinterpret_cast(buf), size, -65535.0); + break; + default: + ET_CHECK_MSG( + false, + "Unsupported scalar type %s", + executorch::runtime::toString(scalar_type)); + break; + } + } +} +} // namespace + +KVManager::KVManager(Metadata metadata, std::unique_ptr method_meta) + : metadata_(metadata) { + Result attention_mask = method_meta->input_tensor_meta(1); + attention_mask_dtype_ = attention_mask->scalar_type(); + + // inputs are [input_tokens, attention_mask, (sliding window attention_mask), + // (input_pos), kv_caches] search kv_cache in inputs + for (int i = 2; i < method_meta->num_inputs(); i++) { + Result tensor_meta = method_meta->input_tensor_meta(i); + // k_cache: [1, n_heads, head_dim, seq_len] + size_t tensor_nbytes = tensor_meta->nbytes(); + size_t expected_tensor_nbytes = metadata_.head_dim * metadata_.num_heads * + metadata_.max_cache_len * getDtypeSize(tensor_meta->scalar_type()); + if (tensor_nbytes != expected_tensor_nbytes) { + // Not a kv_cache tensor (e.g. input_pos, sliding window attention mask). 
+ continue; + } + if (kv_cache_dtype_ == executorch::aten::ScalarType::Undefined) { + kv_cache_dtype_ = tensor_meta->scalar_type(); + } else { + ET_CHECK_MSG( + tensor_meta->scalar_type() == kv_cache_dtype_, + "Currently mixed scalar type of kv_cache is not allowed"); + } + } + ET_CHECK_MSG( + kv_cache_dtype_ != executorch::aten::ScalarType::Undefined, + "kv_cache_dtype was not detected from method inputs"); k_cache_.resize(metadata_.num_layers); v_cache_.resize(metadata_.num_layers); // Calculate cache size size_t cache_in_bytes = metadata_.num_layers * metadata_.num_heads * - metadata_.head_dim * metadata_.max_cache_len * sizeof(T); + metadata_.head_dim * metadata_.max_cache_len * + getDtypeSize(kv_cache_dtype_); size_t cache_out_bytes = metadata_.num_layers * metadata_.num_heads * - metadata_.head_dim * metadata_.max_ar_len * sizeof(T); + metadata_.head_dim * metadata_.max_ar_len * getDtypeSize(kv_cache_dtype_); total_cache_size_ = 2 * (cache_in_bytes + cache_out_bytes); }; -template -void KVManager::init_attention_mask( - uint16_t* attention_mask, +void KVManager::init_attention_mask( + std::byte* attention_mask, const std::vector& attention_map, int32_t ar_len, int32_t n_past) { @@ -33,38 +114,51 @@ void KVManager::init_attention_mask( "The size of attention_map (%zu) doesn't match with ar_len (%d)", attention_map.size(), ar_len); - uint16_t neg_val = 0; - uint16_t pos_val = 65535; // Clear the attention mask - std::fill_n(attention_mask, ar_len * metadata_.context_len, neg_val); + fill_mask( + attention_mask_dtype_, + attention_mask, + ar_len * metadata_.context_len, + /*use_pos_value=*/false); // SMART_MASK requires special handling of attention mask - uint16_t* past_ptr = attention_mask; - uint16_t* new_ptr = attention_mask + (metadata_.context_len - ar_len); + std::byte* past_ptr = attention_mask; + std::byte* new_ptr = attention_mask + + (metadata_.context_len - ar_len) * getDtypeSize(attention_mask_dtype_); // All inputs will necessarily attend to n_past 
and itself for (int i = 0; i < ar_len; i++) { // Iterate across ar_len if (attention_map[i] < 0) { // If negative, attend to only past tokens - std::fill_n(past_ptr, n_past, pos_val); + fill_mask( + attention_mask_dtype_, + past_ptr, + n_past, + /*use_pos_value=*/true); } else { // If positive, copy attention map from (relative to 0th input) parent // Parent token index const int32_t pidx = attention_map[i]; - uint16_t* parent_ptr = attention_mask + pidx * metadata_.context_len; + std::byte* parent_ptr = attention_mask + + pidx * metadata_.context_len * getDtypeSize(attention_mask_dtype_); std::memcpy( - past_ptr, parent_ptr, metadata_.context_len * sizeof(uint16_t)); + past_ptr, + parent_ptr, + metadata_.context_len * getDtypeSize(attention_mask_dtype_)); } // Attend to itself - new_ptr[i] = pos_val; - past_ptr += metadata_.context_len; - new_ptr += metadata_.context_len; + fill_mask( + attention_mask_dtype_, + new_ptr + i * getDtypeSize(attention_mask_dtype_), + 1, + /*use_pos_value=*/true); + past_ptr += metadata_.context_len * getDtypeSize(attention_mask_dtype_); + new_ptr += metadata_.context_len * getDtypeSize(attention_mask_dtype_); } } -template -void KVManager::init_attention_mask( - uint16_t* attention_mask, +void KVManager::init_attention_mask( + std::byte* attention_mask, const std::vector& attention_map, int32_t ar_len, int32_t n_past, @@ -75,30 +169,44 @@ void KVManager::init_attention_mask( "The size of attention_map (%zu) doesn't match with ar_len (%d)", attention_map.size(), ar_len); - uint16_t neg_val = 0; - uint16_t pos_val = 65535; // Clear the attention mask - std::fill_n(attention_mask, ar_len * metadata_.context_len, neg_val); + fill_mask( + attention_mask_dtype_, + attention_mask, + ar_len * metadata_.context_len, + /*use_pos_value=*/false); // SMART_MASK requires special handling of attention mask - uint16_t* past_ptr = attention_mask; - uint16_t* new_ptr = attention_mask + (metadata_.context_len - ar_len); + std::byte* past_ptr = 
attention_mask; + std::byte* new_ptr = attention_mask + + (metadata_.context_len - ar_len) * getDtypeSize(attention_mask_dtype_); // All inputs will necessarily attend to n_past and itself for (int i = 0; i < ar_len; i++) { // Iterate across ar_len if (attention_map[i] < 0) { // If negative, attend to only past tokens - std::fill_n(past_ptr, n_past, pos_val); + fill_mask( + attention_mask_dtype_, + past_ptr, + n_past, + /*use_pos_value=*/true); } else { // If positive, copy attention map from (relative to 0th input) parent // Parent token index const int32_t pidx = attention_map[i]; - uint16_t* parent_ptr = attention_mask + pidx * metadata_.context_len; + std::byte* parent_ptr = attention_mask + + pidx * metadata_.context_len * getDtypeSize(attention_mask_dtype_); std::memcpy( - past_ptr, parent_ptr, metadata_.context_len * sizeof(uint16_t)); + past_ptr, + parent_ptr, + metadata_.context_len * getDtypeSize(attention_mask_dtype_)); } // Attend to itself - new_ptr[i] = pos_val; + fill_mask( + attention_mask_dtype_, + new_ptr + i * getDtypeSize(attention_mask_dtype_), + 1, + /*use_pos_value=*/true); // mask by limitation of sliding_window int32_t available_context_len = position_offset.empty() @@ -107,87 +215,73 @@ void KVManager::init_attention_mask( // if available_context_len is less than 0, it means we need to mask some // tokens in the past to avoid exceeding the sliding window if (available_context_len < 0) { - std::fill_n(past_ptr, -available_context_len, neg_val); + fill_mask( + attention_mask_dtype_, + past_ptr, + -available_context_len, + /*use_pos_value=*/false); } - past_ptr += metadata_.context_len; - new_ptr += metadata_.context_len; + past_ptr += metadata_.context_len * getDtypeSize(attention_mask_dtype_); + new_ptr += metadata_.context_len * getDtypeSize(attention_mask_dtype_); } } -template -void KVManager::update_attention_mask( - uint16_t* attention_mask, +void KVManager::update_attention_mask( + std::byte* attention_mask, int32_t ar_len, int32_t 
n_past, int32_t n_update) { - uint16_t pos_val = 65535; - uint16_t* cur_ptr = attention_mask; - cur_ptr += n_past; + std::byte* cur_ptr = + attention_mask + n_past * getDtypeSize(attention_mask_dtype_); for (int i = 0; i < ar_len; i++) { - std::fill_n(cur_ptr, n_update, pos_val); - cur_ptr += metadata_.context_len; + fill_mask(attention_mask_dtype_, cur_ptr, n_update, /*use_pos_value=*/true); + cur_ptr += metadata_.context_len * getDtypeSize(attention_mask_dtype_); } } -template -void KVManager::update_attention_mask( - uint16_t* attention_mask, +void KVManager::update_attention_mask( + std::byte* attention_mask, int32_t ar_len, int32_t n_past, int32_t n_update, int32_t sliding_window, const std::vector& position_offset) { - uint16_t pos_val = 65535; - uint16_t neg_val = 0; - uint16_t* cur_ptr = attention_mask; - cur_ptr += n_past; + std::byte* cur_ptr = + attention_mask + n_past * getDtypeSize(attention_mask_dtype_); for (int i = 0; i < ar_len; i++) { - std::fill_n(cur_ptr, n_update, pos_val); + fill_mask(attention_mask_dtype_, cur_ptr, n_update, /*use_pos_value=*/true); int32_t available_cache_len = position_offset.empty() ? 
sliding_window - (i + 1) : sliding_window - (position_offset[i] + 1); if (n_past + n_update > available_cache_len) { - std::fill_n( - cur_ptr - n_past, n_past + n_update - available_cache_len, neg_val); + fill_mask( + attention_mask_dtype_, + cur_ptr - n_past * getDtypeSize(attention_mask_dtype_), + n_past + n_update - available_cache_len, + /*use_pos_value=*/false); } - cur_ptr += metadata_.context_len; + cur_ptr += metadata_.context_len * getDtypeSize(attention_mask_dtype_); } } -template -void KVManager::init_cache(IMemAlloc* buffer_manager, int32_t ar_len) { +void KVManager::init_cache(IMemAlloc* buffer_manager, int32_t ar_len) { cur_ar_len_ = ar_len; - const size_t max_in_cache_block_in_bytes = - metadata_.max_cache_len * sizeof(T); - const size_t max_out_cache_block_in_bytes = metadata_.max_ar_len * sizeof(T); - - const size_t cache_in_bytes = - metadata_.num_heads * metadata_.head_dim * max_in_cache_block_in_bytes; - const size_t cache_out_bytes = - metadata_.num_heads * metadata_.head_dim * max_out_cache_block_in_bytes; + const size_t cache_in_bytes = metadata_.num_heads * metadata_.head_dim * + metadata_.max_cache_len * getDtypeSize(kv_cache_dtype_); + const size_t cache_out_bytes = metadata_.num_heads * metadata_.head_dim * + metadata_.max_ar_len * getDtypeSize(kv_cache_dtype_); for (int layer = 0; layer < metadata_.num_layers; ++layer) { - // Allocate buffer for key cache and value cache - T* single_layer_k_cache_in = - reinterpret_cast(buffer_manager->allocate(cache_in_bytes)); - T* single_layer_k_cache_out = - reinterpret_cast(buffer_manager->allocate(cache_out_bytes)); - T* single_layer_v_cache_in = - reinterpret_cast(buffer_manager->allocate(cache_in_bytes)); - T* single_layer_v_cache_out = - reinterpret_cast(buffer_manager->allocate(cache_out_bytes)); - - k_cache_[layer].buffer = single_layer_k_cache_in; - k_cache_[layer].output_buffer = single_layer_k_cache_out; - v_cache_[layer].buffer = single_layer_v_cache_in; - v_cache_[layer].output_buffer = 
single_layer_v_cache_out; + k_cache_[layer].buffer = buffer_manager->allocate(cache_in_bytes); + k_cache_[layer].output_buffer = buffer_manager->allocate(cache_out_bytes); + v_cache_[layer].buffer = buffer_manager->allocate(cache_in_bytes); + v_cache_[layer].output_buffer = buffer_manager->allocate(cache_out_bytes); } } -template -void KVManager::rearrange_cache(int32_t ar_len_dst) { +void KVManager::rearrange_cache(int32_t ar_len_dst) { // Don't need to rearrange if cur_ar_len_ is equal to target ar_len if (cur_ar_len_ == ar_len_dst) return; @@ -199,75 +293,73 @@ void KVManager::rearrange_cache(int32_t ar_len_dst) { cur_ar_len_ = ar_len_dst; } -template -void KVManager::rearrange_key(KVCache& k_cache, int32_t ar_len_dst) { +void KVManager::rearrange_key(KVCache& k_cache, int32_t ar_len_dst) { const int32_t src_cache_num = (cur_ar_len_ == metadata_.context_len) ? metadata_.context_len : metadata_.context_len - cur_ar_len_; const int32_t dst_cache_num = metadata_.context_len - ar_len_dst; - T* k_cache_in_read_ptr = k_cache.buffer; - T* k_cache_in_write_ptr = k_cache.buffer; - + std::byte* k_cache_in_read_ptr = k_cache.buffer; + std::byte* k_cache_in_write_ptr = k_cache.buffer; + size_t src_cache_nbytes = src_cache_num * getDtypeSize(kv_cache_dtype_); + size_t dst_cache_nbytes = dst_cache_num * getDtypeSize(kv_cache_dtype_); if (src_cache_num > dst_cache_num) { // copy from first dimension for (int i = 0; i < metadata_.head_dim * metadata_.num_heads; i++) { - std::memmove( - k_cache_in_write_ptr, k_cache_in_read_ptr, dst_cache_num * sizeof(T)); - k_cache_in_read_ptr += src_cache_num; - k_cache_in_write_ptr += dst_cache_num; + std::memmove(k_cache_in_write_ptr, k_cache_in_read_ptr, dst_cache_nbytes); + k_cache_in_read_ptr += src_cache_nbytes; + k_cache_in_write_ptr += dst_cache_nbytes; } } else { k_cache_in_read_ptr += - (metadata_.head_dim * metadata_.num_heads - 1) * src_cache_num; + (metadata_.head_dim * metadata_.num_heads - 1) * src_cache_nbytes; 
k_cache_in_write_ptr += - (metadata_.head_dim * metadata_.num_heads - 1) * dst_cache_num; + (metadata_.head_dim * metadata_.num_heads - 1) * dst_cache_nbytes; // copy from last dimension for (int i = 0; i < metadata_.head_dim * metadata_.num_heads; i++) { - std::memmove( - k_cache_in_write_ptr, k_cache_in_read_ptr, src_cache_num * sizeof(T)); - k_cache_in_read_ptr -= src_cache_num; - k_cache_in_write_ptr -= dst_cache_num; + std::memmove(k_cache_in_write_ptr, k_cache_in_read_ptr, src_cache_nbytes); + k_cache_in_read_ptr -= src_cache_nbytes; + k_cache_in_write_ptr -= dst_cache_nbytes; } } } -template -void KVManager::rearrange_value(KVCache& v_cache, int32_t ar_len_dst) { +void KVManager::rearrange_value(KVCache& v_cache, int32_t ar_len_dst) { const int32_t src_cache_num = (cur_ar_len_ == metadata_.context_len) ? metadata_.context_len : metadata_.context_len - cur_ar_len_; const int32_t dst_cache_num = metadata_.context_len - ar_len_dst; - T* v_cache_in_read_ptr = v_cache.buffer; - T* v_cache_in_write_ptr = v_cache.buffer; + std::byte* v_cache_in_read_ptr = v_cache.buffer; + std::byte* v_cache_in_write_ptr = v_cache.buffer; + size_t src_cache_nbytes = src_cache_num * getDtypeSize(kv_cache_dtype_); + size_t dst_cache_nbytes = dst_cache_num * getDtypeSize(kv_cache_dtype_); if (src_cache_num > dst_cache_num) { // copy from first dimension for (int i = 0; i < metadata_.num_heads; i++) { std::memmove( v_cache_in_write_ptr, v_cache_in_read_ptr, - dst_cache_num * metadata_.head_dim * sizeof(T)); - v_cache_in_read_ptr += src_cache_num * metadata_.head_dim; - v_cache_in_write_ptr += dst_cache_num * metadata_.head_dim; + dst_cache_nbytes * metadata_.head_dim); + v_cache_in_read_ptr += src_cache_nbytes * metadata_.head_dim; + v_cache_in_write_ptr += dst_cache_nbytes * metadata_.head_dim; } } else { v_cache_in_read_ptr += - metadata_.head_dim * (metadata_.num_heads - 1) * src_cache_num; + metadata_.head_dim * (metadata_.num_heads - 1) * src_cache_nbytes; v_cache_in_write_ptr += 
- metadata_.head_dim * (metadata_.num_heads - 1) * dst_cache_num; + metadata_.head_dim * (metadata_.num_heads - 1) * dst_cache_nbytes; // copy from last dimension for (int i = 0; i < metadata_.num_heads; i++) { std::memmove( v_cache_in_write_ptr, v_cache_in_read_ptr, - src_cache_num * metadata_.head_dim * sizeof(T)); - v_cache_in_read_ptr -= src_cache_num * metadata_.head_dim; - v_cache_in_write_ptr -= dst_cache_num * metadata_.head_dim; + src_cache_nbytes * metadata_.head_dim); + v_cache_in_read_ptr -= src_cache_nbytes * metadata_.head_dim; + v_cache_in_write_ptr -= dst_cache_nbytes * metadata_.head_dim; } } } -template -void KVManager::update_cache( +void KVManager::update_cache( int32_t ar_len, int32_t n_past, int32_t n_update, @@ -283,20 +375,19 @@ void KVManager::update_cache( } } -template -void KVManager::update_key( - KVCache& k_cache, +void KVManager::update_key( + KVCache& k_cache, int32_t n_past, int32_t n_update, const std::vector& selected) { - T* write_ptr = k_cache.buffer; - T* read_ptr = k_cache.output_buffer; - const int32_t copy_size = n_update * sizeof(T); + std::byte* write_ptr = k_cache.buffer; + std::byte* read_ptr = k_cache.output_buffer; + const int32_t copy_size = n_update * getDtypeSize(kv_cache_dtype_); const int32_t iter_size = (cur_ar_len_ == metadata_.context_len) - ? metadata_.context_len - : metadata_.context_len - cur_ar_len_; - const int32_t out_size = cur_ar_len_; - const int32_t past_size = n_past; + ? 
metadata_.context_len * getDtypeSize(kv_cache_dtype_) + : (metadata_.context_len - cur_ar_len_) * getDtypeSize(kv_cache_dtype_); + const int32_t out_size = cur_ar_len_ * getDtypeSize(kv_cache_dtype_); + const int32_t past_size = n_past * getDtypeSize(kv_cache_dtype_); const int32_t n_iter = metadata_.head_dim * metadata_.num_heads; write_ptr += past_size; @@ -316,7 +407,11 @@ void KVManager::update_key( for (int i = 0; i < n_iter; ++i) { auto wp = write_ptr, rp = read_ptr; for (auto ind : true_indices) { - *wp++ = rp[ind]; + std::memmove( + wp, + rp + ind * getDtypeSize(kv_cache_dtype_), + getDtypeSize(kv_cache_dtype_)); + wp += getDtypeSize(kv_cache_dtype_); } write_ptr += iter_size; read_ptr += out_size; @@ -324,21 +419,25 @@ void KVManager::update_key( } } -template -void KVManager::update_value( - KVCache& v_cache, +void KVManager::update_value( + KVCache& v_cache, int32_t n_past, int32_t n_update, const std::vector& selected) { - T* write_ptr = v_cache.buffer; - T* read_ptr = v_cache.output_buffer; - const int32_t copy_size = n_update * metadata_.head_dim * sizeof(T); - const int32_t past_size = n_past * metadata_.head_dim; + std::byte* write_ptr = v_cache.buffer; + std::byte* read_ptr = v_cache.output_buffer; + const int32_t copy_size = + n_update * metadata_.head_dim * getDtypeSize(kv_cache_dtype_); + const int32_t past_size = + n_past * metadata_.head_dim * getDtypeSize(kv_cache_dtype_); const int32_t n_iter = metadata_.num_heads; const int32_t iter_size = (cur_ar_len_ == metadata_.context_len) - ? metadata_.context_len * metadata_.head_dim - : (metadata_.context_len - cur_ar_len_) * metadata_.head_dim; - const int32_t out_size = cur_ar_len_ * metadata_.head_dim; + ? 
metadata_.context_len * metadata_.head_dim * + getDtypeSize(kv_cache_dtype_) + : (metadata_.context_len - cur_ar_len_) * metadata_.head_dim * + getDtypeSize(kv_cache_dtype_); + const int32_t out_size = + cur_ar_len_ * metadata_.head_dim * getDtypeSize(kv_cache_dtype_); write_ptr += past_size; @@ -354,13 +453,14 @@ void KVManager::update_value( auto wp = write_ptr, rp = read_ptr; for (auto sel : selected) { if (sel) { - std::memcpy(wp, rp, metadata_.head_dim * sizeof(T)); - wp += metadata_.head_dim; + std::memcpy( + wp, rp, metadata_.head_dim * getDtypeSize(kv_cache_dtype_)); + wp += metadata_.head_dim * getDtypeSize(kv_cache_dtype_); update_times--; if (update_times == 0) break; } - rp += metadata_.head_dim; + rp += metadata_.head_dim * getDtypeSize(kv_cache_dtype_); } write_ptr += iter_size; read_ptr += out_size; @@ -368,8 +468,4 @@ void KVManager::update_value( } } -// Explicit instantiations -template class KVManager; -template class KVManager; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h index 06fe88517a7..3b8e67dd38d 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h +++ b/examples/qualcomm/oss_scripts/llama/runner/kv_manager.h @@ -8,6 +8,7 @@ #pragma once #include +#include #include #include #include @@ -15,17 +16,15 @@ namespace example { // Structure to hold key-value cache buffers -template struct KVCache { - T* buffer; - T* output_buffer; + std::byte* buffer; + std::byte* output_buffer; }; /** * @class KVManager * @brief Class for kv cache update, rearrangement, and buffer allocatation. */ -template class KVManager { public: struct Metadata { @@ -36,7 +35,9 @@ class KVManager { int64_t num_heads; int64_t num_layers; }; - KVManager(Metadata metadata); + KVManager( + Metadata metadata, + std::unique_ptr method_meta); /** * @brief Allocate buffer for KV cache and set the cur_ar_len_. 
@@ -71,7 +72,7 @@ class KVManager { * @param n_past Number of past elements in the cache. */ void init_attention_mask( - uint16_t* attention_mask, + std::byte* attention_mask, const std::vector& attention_map, int32_t ar_len, int32_t n_past); @@ -98,7 +99,7 @@ class KVManager { * @param position_offset (optional) attention mask position offset of */ void init_attention_mask( - uint16_t* attention_mask, + std::byte* attention_mask, const std::vector& attention_map, int32_t ar_len, int32_t n_past, @@ -114,7 +115,7 @@ class KVManager { * @param n_update Number of elements to be updated. */ void update_attention_mask( - uint16_t* attention_mask, + std::byte* attention_mask, int32_t ar_len, int32_t n_past, int32_t n_update); @@ -132,7 +133,7 @@ class KVManager { * lookahead decoder */ void update_attention_mask( - uint16_t* attention_mask, + std::byte* attention_mask, int32_t ar_len, int32_t n_past, int32_t n_update, @@ -152,10 +153,10 @@ class KVManager { int32_t n_update, const std::vector& selected); - const std::vector>& get_k_cache_() const { + const std::vector& get_k_cache_() const { return k_cache_; } - const std::vector>& get_v_cache_() const { + const std::vector& get_v_cache_() const { return v_cache_; } @@ -169,15 +170,19 @@ class KVManager { private: // Helper functions to rearrange and update key and value caches - void rearrange_key(KVCache& k_cache, int32_t ar_len_dst); - void rearrange_value(KVCache& v_cache, int32_t ar_len_dst); + + void rearrange_key(KVCache& k_cache, int32_t ar_len_dst); + + void rearrange_value(KVCache& v_cache, int32_t ar_len_dst); + void update_key( - KVCache& k_cache, + KVCache& k_cache, int32_t n_past, int32_t n_update, const std::vector& selected); + void update_value( - KVCache& v_cache, + KVCache& v_cache, int32_t n_past, int32_t n_update, const std::vector& selected); @@ -186,10 +191,14 @@ class KVManager { Metadata metadata_; size_t total_cache_size_; int32_t cur_ar_len_; + executorch::aten::ScalarType attention_mask_dtype_ 
= + executorch::aten::ScalarType::Undefined; + executorch::aten::ScalarType kv_cache_dtype_ = + executorch::aten::ScalarType::Undefined; // Store start pointer of k and v cache for input and output // input: layer -> head * head_dim * max_cache_len // output: layer -> head * head_dim * max_ar_len - std::vector> k_cache_; - std::vector> v_cache_; + std::vector k_cache_; + std::vector v_cache_; }; } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp index f7e44292f26..298fc1ac9ff 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.cpp @@ -13,20 +13,19 @@ using executorch::runtime::Result; namespace example { -template -void LhdTokenGenerator::prepare_io( +void LhdTokenGenerator::prepare_io( std::vector input_tokens, std::vector input_pos) { for (int i = 0; i < metadata_.ar_len; i++) { if (i < input_tokens.size()) { // Prepare pos data - this->input_pos_.data[i] = input_pos[i]; + reinterpret_cast(this->input_pos_.data)[i] = input_pos[i]; // Support CPU 4-bit embedding, which requires int64 input. // However, for QNN embedding, only int32 input is needed. // Therefore, we need to cast to the correct type to write the data. 
if (metadata_.use_int64_token) { - this->input_toks_.data[i] = input_tokens[i]; + reinterpret_cast(this->input_toks_.data)[i] = input_tokens[i]; } else { int32_t* input_toks_ptr = reinterpret_cast(this->input_toks_.data); @@ -36,8 +35,7 @@ void LhdTokenGenerator::prepare_io( } } -template -void LhdTokenGenerator::init_attention_mask(int32_t n_past) { +void LhdTokenGenerator::init_attention_mask(int32_t n_past) { std::vector attention_map; attention_map.reserve(metadata_.ar_len); // Initialize attention mask with current position @@ -73,8 +71,7 @@ void LhdTokenGenerator::init_attention_mask(int32_t n_past) { } } -template -void LhdTokenGenerator::init_lookahead_branch( +void LhdTokenGenerator::init_lookahead_branch( const std::vector& tokens) { for (int i = 0; i < metadata_.ngram - 1; ++i) { for (int j = 0; j < metadata_.window; ++j) { @@ -91,8 +88,7 @@ void LhdTokenGenerator::init_lookahead_branch( is_lhd_branch_initialized_ = true; } -template -void LhdTokenGenerator::init_verification_branch(uint64_t cur_token) { +void LhdTokenGenerator::init_verification_branch(uint64_t cur_token) { const int g_cur = ngrams_pool_.cnt[cur_token]; v_branch_.resize(g_cur); @@ -116,8 +112,7 @@ void LhdTokenGenerator::init_verification_branch(uint64_t cur_token) { } } -template -void LhdTokenGenerator::update_ngrams_pool() { +void LhdTokenGenerator::update_ngrams_pool() { std::vector ngram(metadata_.ngram - 1); // n-gram pool generation for (int f = 0; f < metadata_.window; ++f) { @@ -170,8 +165,7 @@ void LhdTokenGenerator::update_ngrams_pool() { } } -template -void LhdTokenGenerator::update_lookahead_branch( +void LhdTokenGenerator::update_lookahead_branch( const executorch::aten::Tensor& logits_tensor) { for (int i = 0; i < metadata_.window; i++) { lhd_branch_prev_[i] = lhd_branch_[0][i]; @@ -189,8 +183,7 @@ void LhdTokenGenerator::update_lookahead_branch( } } -template -Result LhdTokenGenerator::generate( +Result LhdTokenGenerator::generate( std::vector tokens, int64_t start_pos, 
int32_t seq_len, @@ -427,8 +420,4 @@ Result LhdTokenGenerator::generate( return pos - start_pos; } -// Explicit instantiations -template class LhdTokenGenerator; -template class LhdTokenGenerator; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h index 796dde88014..8fdffb8af72 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h +++ b/examples/qualcomm/oss_scripts/llama/runner/lhd_token_generator.h @@ -15,8 +15,8 @@ namespace example { * @brief Class for generating the token using decoder and key-value manager * with lookahead decoding. */ -template -class LhdTokenGenerator : public TokenGenerator { + +class LhdTokenGenerator : public TokenGenerator { public: struct Metadata { int32_t context_len; @@ -34,18 +34,19 @@ class LhdTokenGenerator : public TokenGenerator { LhdTokenGenerator( tokenizers::Tokenizer* tokenizer, DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& forward_name, std::unique_ptr>&& eos_ids, Metadata metadata, - executorch::llm::Stats* stats) - : TokenGenerator( + executorch::llm::Stats* stats, + std::unique_ptr method_meta) + : TokenGenerator( tokenizer, decoder_runner, kv_manager, forward_name, std::move(eos_ids), - typename TokenGenerator::Metadata{ + TokenGenerator::Metadata{ metadata.context_len, metadata.num_heads, metadata.num_layers, @@ -54,7 +55,8 @@ class LhdTokenGenerator : public TokenGenerator { metadata.use_int64_token, metadata.sliding_window, metadata.cache_mode}, - stats), + stats, + std::move(method_meta)), metadata_(metadata), lhd_branch_(metadata.ngram - 1, std::vector(metadata.window)), lhd_branch_prev_(metadata.window), @@ -104,7 +106,7 @@ class LhdTokenGenerator : public TokenGenerator { private: // Bring base class's virtual prepare_io into scope so the overload below // does not hide it (-Woverloaded-virtual). 
- using TokenGenerator::prepare_io; + using TokenGenerator::prepare_io; /** * @brief Fill in I/O buffers with prompt token and position. * @param cur_token Current token. diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_lhd_token_generator.cpp b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_lhd_token_generator.cpp index 14a93104e1a..de8d1bea0fe 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_lhd_token_generator.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_lhd_token_generator.cpp @@ -13,8 +13,7 @@ using executorch::runtime::Result; namespace example { -template -void MultimodalLhdTokenGenerator::prepare_io( +void MultimodalLhdTokenGenerator::prepare_io( std::vector input_tokens, std::vector input_pos) { for (int i = 0; i < metadata_.ar_len; i++) { @@ -51,8 +50,7 @@ void MultimodalLhdTokenGenerator::prepare_io( } } -template -void MultimodalLhdTokenGenerator::init_attention_mask(int32_t n_past) { +void MultimodalLhdTokenGenerator::init_attention_mask(int32_t n_past) { std::vector attention_map; attention_map.reserve(metadata_.ar_len); // Initialize attention mask with current position @@ -88,8 +86,7 @@ void MultimodalLhdTokenGenerator::init_attention_mask(int32_t n_past) { } } -template -void MultimodalLhdTokenGenerator::init_lookahead_branch( +void MultimodalLhdTokenGenerator::init_lookahead_branch( const std::vector& tokens) { for (int i = 0; i < metadata_.ngram - 1; ++i) { for (int j = 0; j < metadata_.window; ++j) { @@ -106,9 +103,7 @@ void MultimodalLhdTokenGenerator::init_lookahead_branch( is_lhd_branch_initialized_ = true; } -template -void MultimodalLhdTokenGenerator::init_verification_branch( - uint64_t cur_token) { +void MultimodalLhdTokenGenerator::init_verification_branch(uint64_t cur_token) { const int g_cur = ngrams_pool_.cnt[cur_token]; v_branch_.resize(g_cur); @@ -132,8 +127,7 @@ void 
MultimodalLhdTokenGenerator::init_verification_branch( } } -template -void MultimodalLhdTokenGenerator::update_ngrams_pool() { +void MultimodalLhdTokenGenerator::update_ngrams_pool() { std::vector ngram(metadata_.ngram - 1); // n-gram pool generation for (int f = 0; f < metadata_.window; ++f) { @@ -186,8 +180,7 @@ void MultimodalLhdTokenGenerator::update_ngrams_pool() { } } -template -void MultimodalLhdTokenGenerator::update_lookahead_branch( +void MultimodalLhdTokenGenerator::update_lookahead_branch( const executorch::aten::Tensor& logits_tensor) { for (int i = 0; i < metadata_.window; i++) { lhd_branch_prev_[i] = lhd_branch_[0][i]; @@ -205,8 +198,7 @@ void MultimodalLhdTokenGenerator::update_lookahead_branch( } } -template -Result MultimodalLhdTokenGenerator::generate( +Result MultimodalLhdTokenGenerator::generate( std::vector tokens, int64_t start_pos, int32_t seq_len, @@ -412,8 +404,4 @@ Result MultimodalLhdTokenGenerator::generate( return pos - start_pos; } -// Explicit instantiations -template class MultimodalLhdTokenGenerator; -template class MultimodalLhdTokenGenerator; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_lhd_token_generator.h b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_lhd_token_generator.h index 7494afec6da..6ffe285e536 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_lhd_token_generator.h +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_lhd_token_generator.h @@ -15,9 +15,7 @@ namespace example { * @class MultimodalLhdTokenGenerator * @brief Extended LhdTokenGenerator with multimodal embedding support */ -template -class MultimodalLhdTokenGenerator - : public example::MultimodalTokenGenerator { +class MultimodalLhdTokenGenerator : public example::MultimodalTokenGenerator { public: struct Metadata { int32_t context_len; @@ -37,19 +35,20 @@ class MultimodalLhdTokenGenerator 
tokenizers::Tokenizer* tokenizer, TokenEmbeddingProcessor* embedding_runner, DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& forward_name, std::unique_ptr>&& eos_ids, Metadata metadata, - executorch::llm::Stats* stats) - : MultimodalTokenGenerator( + executorch::llm::Stats* stats, + std::unique_ptr method_meta) + : MultimodalTokenGenerator( tokenizer, embedding_runner, decoder_runner, kv_manager, forward_name, std::move(eos_ids), - typename MultimodalTokenGenerator::Metadata{ + MultimodalTokenGenerator::Metadata{ metadata.context_len, metadata.num_heads, metadata.num_layers, @@ -59,7 +58,8 @@ class MultimodalLhdTokenGenerator metadata.sliding_window, metadata.cache_mode, metadata.embedding_dim}, - stats), + stats, + std::move(method_meta)), tok_embedding_runner_(embedding_runner), metadata_(metadata), lhd_branch_(metadata.ngram - 1, std::vector(metadata.window)), @@ -110,7 +110,7 @@ class MultimodalLhdTokenGenerator private: // Bring base class's virtual prepare_io into scope so the overload below // does not hide it (-Woverloaded-virtual). - using TokenGenerator::prepare_io; + using TokenGenerator::prepare_io; /** * @brief Fill in I/O buffers with prompt token and position. * @param cur_token Current token. 
diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_prompt_processor.cpp b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_prompt_processor.cpp index 2859e16a42a..f63a431791b 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_prompt_processor.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_prompt_processor.cpp @@ -16,13 +16,13 @@ using executorch::runtime::TensorInfo; namespace example { -template -MultimodalPromptProcessor::MultimodalPromptProcessor( +MultimodalPromptProcessor::MultimodalPromptProcessor( DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& method_name, - Metadata metadata) - : PromptProcessor( + Metadata metadata, + std::unique_ptr method_meta) + : PromptProcessor( decoder_runner, kv_manager, method_name, @@ -33,7 +33,8 @@ MultimodalPromptProcessor::MultimodalPromptProcessor( metadata.vocab_size, metadata.use_int64_token, metadata.sliding_window, - metadata.cache_mode}), + metadata.cache_mode}, + std::move(method_meta)), metadata_(metadata) { // Set input_toks_.size to 0 since we use embeddings instead input_toks_.size = 0; @@ -41,8 +42,7 @@ MultimodalPromptProcessor::MultimodalPromptProcessor( metadata_.ar_len * metadata_.embedding_dim * sizeof(float); }; -template -void MultimodalPromptProcessor::init_io( +void MultimodalPromptProcessor::init_io( IMemAlloc* buffer_manager, Result method_meta) { size_t idx = 0; @@ -66,8 +66,7 @@ void MultimodalPromptProcessor::init_io( // [I]: attention_mask Result attention_mask = method_meta->input_tensor_meta(idx++); - attention_mask_.data = reinterpret_cast( - buffer_manager->allocate(attention_mask_.size)); + attention_mask_.data = buffer_manager->allocate(attention_mask_.size); attention_mask_.tensor = std::make_unique( attention_mask->scalar_type(), attention_mask->sizes().size(), @@ -83,8 +82,8 @@ void 
MultimodalPromptProcessor::init_io( if (metadata_.cache_mode == CacheMode::HybridCache) { Result window_attention_mask = method_meta->input_tensor_meta(idx++); - window_attention_mask_.data = reinterpret_cast( - buffer_manager->allocate(window_attention_mask_.size)); + window_attention_mask_.data = + buffer_manager->allocate(window_attention_mask_.size); window_attention_mask_.tensor = std::make_unique( window_attention_mask->scalar_type(), window_attention_mask->sizes().size(), @@ -120,32 +119,29 @@ void MultimodalPromptProcessor::init_io( for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>& cache = (cache_group == 0 ? k_cache_in_ : v_cache_in_); - std::vector> cache_ptrs = (cache_group == 0) + std::vector cache_ptrs = (cache_group == 0) ? kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { Result kv_cache = method_meta->input_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer].buffer; - cache[layer] = std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), const_cast(kv_cache->sizes().data()), - cache_ptr, + cache_ptrs[layer].buffer, const_cast( kv_cache->dim_order().data())); input_tensors_.emplace_back(cache[layer].get()); buffer_manager->add_memory_info( - cache_ptr, cache[layer]->nbytes(), kv_cache.get()); + cache_ptrs[layer].buffer, cache[layer]->nbytes(), kv_cache.get()); } } } // [O]: logits Result logits = method_meta->output_tensor_meta(0); - logits_.data = - reinterpret_cast(buffer_manager->allocate(logits_.size)); + logits_.data = buffer_manager->allocate(logits_.size); logits_.tensor = std::make_unique( logits->scalar_type(), logits->sizes().size(), @@ -160,21 +156,22 @@ void MultimodalPromptProcessor::init_io( for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>& cache = (cache_group == 0 ? k_cache_out_ : v_cache_out_); - std::vector> cache_ptrs = (cache_group == 0) + std::vector cache_ptrs = (cache_group == 0) ? 
kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { Result kv_cache = method_meta->output_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer].output_buffer; cache[layer] = std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), const_cast(kv_cache->sizes().data()), - cache_ptr, + cache_ptrs[layer].output_buffer, const_cast(kv_cache->dim_order().data())); output_tensors_.emplace_back(cache[layer].get()); buffer_manager->add_memory_info( - cache_ptr, cache[layer]->nbytes(), kv_cache.get()); + cache_ptrs[layer].output_buffer, + cache[layer]->nbytes(), + kv_cache.get()); } } @@ -186,8 +183,7 @@ void MultimodalPromptProcessor::init_io( } // prepare embedding -template -void MultimodalPromptProcessor::prepare_io( +void MultimodalPromptProcessor::prepare_io( const TensorStruct& prompt_embedding, int32_t num_prompt_tokens, int64_t prompt_pos, @@ -208,8 +204,7 @@ void MultimodalPromptProcessor::prepare_io( } } -template -Result MultimodalPromptProcessor::prefill( +Result MultimodalPromptProcessor::prefill( const TensorStruct& prompt_embedding, int64_t start_pos, bool dump_logits, @@ -301,8 +296,4 @@ Result MultimodalPromptProcessor::prefill( return cur_token; } -// Explicit instantiations -template class MultimodalPromptProcessor; -template class MultimodalPromptProcessor; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_prompt_processor.h b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_prompt_processor.h index fcfc07c9590..c2769ed9f50 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_prompt_processor.h +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_prompt_processor.h @@ -16,8 +16,7 @@ namespace example { * @class MultimodalPromptProcessor * @brief Extended PromptProcessor with multimodal embedding support */ -template -class 
MultimodalPromptProcessor : public example::PromptProcessor { +class MultimodalPromptProcessor : public example::PromptProcessor { public: struct Metadata { int32_t context_len; @@ -33,9 +32,10 @@ class MultimodalPromptProcessor : public example::PromptProcessor { MultimodalPromptProcessor( DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& method_name, - Metadata metadata); + Metadata metadata, + std::unique_ptr method_meta); int64_t get_num_heads() const { return metadata_.num_heads; @@ -74,34 +74,29 @@ class MultimodalPromptProcessor : public example::PromptProcessor { * @return Total I/O size in bytes. */ inline const size_t total_prompt_processor_io_size_in_bytes() const { - if (metadata_.cache_mode == CacheMode::HybridCache) { - return input_toks_.size + input_pos_.size + attention_mask_.size + - window_attention_mask_.size + logits_.size + input_embedding_.size; - } else { - return input_toks_.size + input_pos_.size + attention_mask_.size + - logits_.size + input_embedding_.size; - } + return input_toks_.size + input_pos_.size + attention_mask_.size + + window_attention_mask_.size + logits_.size + input_embedding_.size; } private: // Reuse members from token_generator - using PromptProcessor::decoder_runner_; - using PromptProcessor::kv_manager_; - using PromptProcessor::method_name_; - using PromptProcessor::k_cache_in_; - using PromptProcessor::v_cache_in_; - using PromptProcessor::k_cache_out_; - using PromptProcessor::v_cache_out_; - using PromptProcessor::input_toks_; - using PromptProcessor::input_pos_; - using PromptProcessor::attention_mask_; - using PromptProcessor::window_attention_mask_; - using PromptProcessor::logits_; - using PromptProcessor::inputs_; - using PromptProcessor::input_tensors_; - using PromptProcessor::output_tensors_; - using PromptProcessor::prompt_all_logits_; - using PromptProcessor::is_bert; + using PromptProcessor::attention_mask_; + using PromptProcessor::decoder_runner_; + 
using PromptProcessor::input_pos_; + using PromptProcessor::input_tensors_; + using PromptProcessor::input_toks_; + using PromptProcessor::inputs_; + using PromptProcessor::is_bert; + using PromptProcessor::k_cache_in_; + using PromptProcessor::k_cache_out_; + using PromptProcessor::kv_manager_; + using PromptProcessor::logits_; + using PromptProcessor::method_name_; + using PromptProcessor::output_tensors_; + using PromptProcessor::prompt_all_logits_; + using PromptProcessor::v_cache_in_; + using PromptProcessor::v_cache_out_; + using PromptProcessor::window_attention_mask_; /** * @brief Fill in I/O buffers with embedding data and position. diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.cpp index 32e3baf27a9..32575994222 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.cpp @@ -74,17 +74,17 @@ void print_performance_report( void save_logits( const std::string& dump_logits_path, - const std::vector& prefill_logits, - const std::vector& decode_logits) { + const std::vector& prefill_logits, + const std::vector& decode_logits) { std::ofstream outFile(dump_logits_path.c_str(), std::ios::binary); if (outFile.is_open()) { outFile.write( reinterpret_cast(prefill_logits.data()), - prefill_logits.size() * sizeof(uint16_t)); + prefill_logits.size()); outFile.write( reinterpret_cast(decode_logits.data()), - decode_logits.size() * sizeof(uint16_t)); + decode_logits.size()); outFile.close(); } else { ET_CHECK_MSG(false, "Error saving the dump logits file"); @@ -93,8 +93,7 @@ void save_logits( } // namespace -template -QNNMultimodalRunner::QNNMultimodalRunner( +QNNMultimodalRunner::QNNMultimodalRunner( std::unique_ptr encoder, std::unique_ptr tok_embedding, std::unique_ptr text_decoder, @@ -148,16 +147,14 @@ 
QNNMultimodalRunner::QNNMultimodalRunner( ET_LOG(Info, "eval mode=%d", eval_mode_); } -template -bool QNNMultimodalRunner::is_loaded() const { +bool QNNMultimodalRunner::is_loaded() const { return encoder_->is_loaded() && tok_embedding_->is_loaded() && text_decoder_->is_loaded() && embedding_merger_ && tokenizer_ && decoder_runner_ && prompt_processor_ && token_generator_ && kv_manager_ && buffer_manager_; } -template -Error QNNMultimodalRunner::load() { +Error QNNMultimodalRunner::load() { if (is_loaded()) { return Error::Ok; } @@ -298,19 +295,22 @@ Error QNNMultimodalRunner::load() { sliding_window = ET_UNWRAP(text_decoder_->get("get_sliding_window")).toInt(); } - kv_manager_ = std::make_unique>(typename KVManager::Metadata{ - context_len_, - head_dim, - max_ar_len, - max_cache_len, - num_heads, - num_layers}); - - prompt_processor_ = std::make_unique>( + kv_manager_ = std::make_unique( + KVManager::Metadata{ + context_len_, + head_dim, + max_ar_len, + max_cache_len, + num_heads, + num_layers}, + std::make_unique(std::move( + text_decoder_->method_meta(token_generator_method_name).get()))); + + prompt_processor_ = std::make_unique( decoder_runner_.get(), kv_manager_.get(), prompt_processor_method_name, - typename MultimodalPromptProcessor::Metadata{ + MultimodalPromptProcessor::Metadata{ context_len_, num_heads, num_layers, @@ -319,7 +319,9 @@ Error QNNMultimodalRunner::load() { use_int64_token, sliding_window, cache_mode_, - static_cast(dim)}); + static_cast(dim)}, + std::make_unique(std::move( + text_decoder_->method_meta(prompt_processor_method_name).get()))); // Initialize EmbeddingGenerator tok_embedding_generator_ = std::make_unique( @@ -333,14 +335,14 @@ Error QNNMultimodalRunner::load() { static_cast(dim)}); if (eval_mode_ == EvalMode::kLookaheadDecoding) { // Initialize TokenGenerator - token_generator_ = std::make_unique>( + token_generator_ = std::make_unique( tokenizer_.get(), tok_embedding_generator_.get(), decoder_runner_.get(), kv_manager_.get(), 
token_generator_method_name, std::move(eos_ids), - typename MultimodalLhdTokenGenerator::Metadata{ + MultimodalLhdTokenGenerator::Metadata{ context_len_, num_heads, num_layers, @@ -353,16 +355,18 @@ Error QNNMultimodalRunner::load() { sliding_window, cache_mode_, static_cast(dim)}, - &stats_); + &stats_, + std::make_unique(std::move( + text_decoder_->method_meta(token_generator_method_name).get()))); } else { - token_generator_ = std::make_unique>( + token_generator_ = std::make_unique( tokenizer_.get(), tok_embedding_generator_.get(), decoder_runner_.get(), kv_manager_.get(), token_generator_method_name, std::move(eos_ids), - typename MultimodalTokenGenerator::Metadata{ + MultimodalTokenGenerator::Metadata{ context_len_, num_heads, num_layers, @@ -372,7 +376,9 @@ Error QNNMultimodalRunner::load() { sliding_window, cache_mode_, static_cast(dim)}, - &stats_); + &stats_, + std::make_unique(std::move( + text_decoder_->method_meta(token_generator_method_name).get()))); } buffer_manager_ = std::make_unique(); @@ -409,8 +415,7 @@ Error QNNMultimodalRunner::load() { return Error::Ok; } -template -executorch::runtime::Error QNNMultimodalRunner::generate( +executorch::runtime::Error QNNMultimodalRunner::generate( const std::vector& inputs, const llm::GenerationConfig& config, std::function token_callback, @@ -561,8 +566,7 @@ executorch::runtime::Error QNNMultimodalRunner::generate( return Error::Ok; } -template -Result QNNMultimodalRunner::get_model_version() { +Result QNNMultimodalRunner::get_model_version() { if (!is_loaded()) { stats_.model_load_start_ms = time_in_ms(); ET_CHECK_OK_OR_RETURN_ERROR(load()); @@ -571,16 +575,11 @@ Result QNNMultimodalRunner::get_model_version() { return model_version_; } -template -Result QNNMultimodalRunner::get_encoder_method_meta() { +Result QNNMultimodalRunner::get_encoder_method_meta() { if (!is_loaded()) { ET_CHECK_OK_OR_RETURN_ERROR(load()); } return encoder_->method_meta(kEncoderForwardName); } -// Explicit instantiations -template 
class QNNMultimodalRunner; -template class QNNMultimodalRunner; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.h b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.h index 5407d5712b7..363ded0f055 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.h +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_runner.h @@ -66,12 +66,6 @@ inline Modality modality_of(const ModelVersion& model_version) { [](const auto& model) { return modality_of(model); }, model_version); } -enum KvBitWidth { - kWidth8 = 8, - kWidth16 = 16, -}; - -template class QNNMultimodalRunner : public executorch::extension::llm::MultimodalRunner { public: @@ -139,11 +133,11 @@ class QNNMultimodalRunner ModelVersion model_version_; std::unique_ptr buffer_manager_; - std::unique_ptr> kv_manager_; + std::unique_ptr kv_manager_; std::unique_ptr tokenizer_; std::unique_ptr decoder_runner_; - std::unique_ptr> prompt_processor_; - std::unique_ptr> token_generator_; + std::unique_ptr prompt_processor_; + std::unique_ptr token_generator_; std::unique_ptr encoder_runner_; std::unique_ptr tok_embedding_runner_; std::unique_ptr tok_embedding_processor_; diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_token_generator.cpp b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_token_generator.cpp index 2ed8ae51f1d..e3f6f8e214e 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_token_generator.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_token_generator.cpp @@ -15,17 +15,17 @@ using executorch::runtime::TensorInfo; namespace example { // Constructor with embedding runner support -template -MultimodalTokenGenerator::MultimodalTokenGenerator( +MultimodalTokenGenerator::MultimodalTokenGenerator( tokenizers::Tokenizer* tokenizer, 
TokenEmbeddingProcessor* tok_embedding_runner, DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& method_name, std::unique_ptr>&& eos_ids, Metadata metadata, - executorch::llm::Stats* stats) - : TokenGenerator( + executorch::llm::Stats* stats, + std::unique_ptr method_meta) + : TokenGenerator( tokenizer, decoder_runner, kv_manager, @@ -39,7 +39,8 @@ MultimodalTokenGenerator::MultimodalTokenGenerator( metadata.use_int64_token, metadata.sliding_window, metadata.cache_mode}, - stats), + stats, + std::move(method_meta)), tok_embedding_runner_(tok_embedding_runner), metadata_(metadata) { // Set input_toks_.size to 0 since we use embeddings instead @@ -48,8 +49,7 @@ MultimodalTokenGenerator::MultimodalTokenGenerator( metadata_.ar_len * metadata_.embedding_dim * sizeof(float); } -template -void MultimodalTokenGenerator::init_io( +void MultimodalTokenGenerator::init_io( IMemAlloc* buffer_manager, Result method_meta) { size_t idx = 0; @@ -73,8 +73,7 @@ void MultimodalTokenGenerator::init_io( // [I]: attention_mask Result attention_mask = method_meta->input_tensor_meta(idx++); - attention_mask_.data = reinterpret_cast( - buffer_manager->allocate(attention_mask_.size)); + attention_mask_.data = buffer_manager->allocate(attention_mask_.size); attention_mask_.tensor = std::make_unique( attention_mask->scalar_type(), attention_mask->sizes().size(), @@ -90,8 +89,8 @@ void MultimodalTokenGenerator::init_io( if (metadata_.cache_mode == CacheMode::HybridCache) { Result window_attention_mask = method_meta->input_tensor_meta(idx++); - window_attention_mask_.data = reinterpret_cast( - buffer_manager->allocate(window_attention_mask_.size)); + window_attention_mask_.data = + buffer_manager->allocate(window_attention_mask_.size); window_attention_mask_.tensor = std::make_unique( window_attention_mask->scalar_type(), window_attention_mask->sizes().size(), @@ -126,30 +125,27 @@ void MultimodalTokenGenerator::init_io( for (int cache_group = 0; 
cache_group < 2; ++cache_group) { std::vector>& cache = (cache_group == 0 ? k_cache_in_ : v_cache_in_); - std::vector> cache_ptrs = (cache_group == 0) + std::vector cache_ptrs = (cache_group == 0) ? kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { Result kv_cache = method_meta->input_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer].buffer; - cache[layer] = std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), const_cast(kv_cache->sizes().data()), - cache_ptr, + cache_ptrs[layer].buffer, const_cast(kv_cache->dim_order().data())); input_tensors_.emplace_back(cache[layer].get()); buffer_manager->add_memory_info( - cache_ptr, cache[layer]->nbytes(), kv_cache.get()); + cache_ptrs[layer].buffer, cache[layer]->nbytes(), kv_cache.get()); } } // [O]: logits Result logits = method_meta->output_tensor_meta(0); - logits_.data = - reinterpret_cast(buffer_manager->allocate(logits_.size)); + logits_.data = buffer_manager->allocate(logits_.size); logits_.tensor = std::make_unique( logits->scalar_type(), logits->sizes().size(), @@ -164,21 +160,22 @@ void MultimodalTokenGenerator::init_io( for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>& cache = (cache_group == 0 ? k_cache_out_ : v_cache_out_); - std::vector> cache_ptrs = (cache_group == 0) + std::vector cache_ptrs = (cache_group == 0) ? 
kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { Result kv_cache = method_meta->output_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer].output_buffer; cache[layer] = std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), const_cast(kv_cache->sizes().data()), - cache_ptr, + cache_ptrs[layer].output_buffer, const_cast(kv_cache->dim_order().data())); output_tensors_.emplace_back(cache[layer].get()); buffer_manager->add_memory_info( - cache_ptr, cache[layer]->nbytes(), kv_cache.get()); + cache_ptrs[layer].output_buffer, + cache[layer]->nbytes(), + kv_cache.get()); } } @@ -190,8 +187,7 @@ void MultimodalTokenGenerator::init_io( } // This function only considers the case where token_generator_ar_len equals 1. -template -void MultimodalTokenGenerator::prepare_io( +void MultimodalTokenGenerator::prepare_io( uint64_t cur_token, int64_t start_pos) { // Generate embedding for current token using embedding runner @@ -209,8 +205,4 @@ void MultimodalTokenGenerator::prepare_io( *input_pos_.data = static_cast(start_pos); } -// Explicit instantiations -template class MultimodalTokenGenerator; -template class MultimodalTokenGenerator; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_token_generator.h b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_token_generator.h index 9eb9c79aaa4..2d0bf9385b4 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_token_generator.h +++ b/examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_token_generator.h @@ -16,8 +16,7 @@ namespace example { * @class MultimodalTokenGenerator * @brief Extended TokenGenerator with multimodal embedding support */ -template -class MultimodalTokenGenerator : public example::TokenGenerator { +class MultimodalTokenGenerator : public example::TokenGenerator { public: struct Metadata { 
int32_t context_len; @@ -36,11 +35,12 @@ class MultimodalTokenGenerator : public example::TokenGenerator { tokenizers::Tokenizer* tokenizer, TokenEmbeddingProcessor* tok_embedding_runner, DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& method_name, std::unique_ptr>&& eos_ids, Metadata metadata, - executorch::llm::Stats* stats); + executorch::llm::Stats* stats, + std::unique_ptr method_meta); virtual ~MultimodalTokenGenerator() = default; @@ -54,36 +54,31 @@ class MultimodalTokenGenerator : public example::TokenGenerator { override; inline const size_t total_token_generator_io_size_in_bytes() const { - if (metadata_.cache_mode == CacheMode::HybridCache) { - return input_toks_.size + input_pos_.size + attention_mask_.size + - window_attention_mask_.size + logits_.size + input_embedding_.size; - } else { - return input_toks_.size + input_pos_.size + attention_mask_.size + - logits_.size + input_embedding_.size; - } + return input_toks_.size + input_pos_.size + attention_mask_.size + + window_attention_mask_.size + logits_.size + input_embedding_.size; } protected: // Reuse members from token_generator - using TokenGenerator::kv_manager_; - using TokenGenerator::input_pos_; - using TokenGenerator::attention_mask_; - using TokenGenerator::window_attention_mask_; - using TokenGenerator::inputs_; - using TokenGenerator::input_tensors_; - using TokenGenerator::output_tensors_; + using TokenGenerator::attention_mask_; + using TokenGenerator::input_pos_; + using TokenGenerator::input_tensors_; + using TokenGenerator::inputs_; + using TokenGenerator::kv_manager_; + using TokenGenerator::output_tensors_; + using TokenGenerator::window_attention_mask_; // Additional members specific to multimodal TensorStruct input_embedding_; private: // Reuse members from token_generator - using TokenGenerator::input_toks_; - using TokenGenerator::logits_; - using TokenGenerator::k_cache_in_; - using TokenGenerator::v_cache_in_; - using 
TokenGenerator::k_cache_out_; - using TokenGenerator::v_cache_out_; + using TokenGenerator::input_toks_; + using TokenGenerator::k_cache_in_; + using TokenGenerator::k_cache_out_; + using TokenGenerator::logits_; + using TokenGenerator::v_cache_in_; + using TokenGenerator::v_cache_out_; // Additional members specific to multimodal TokenEmbeddingProcessor* tok_embedding_runner_; diff --git a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp index 59744d488bd..0cb52246a39 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.cpp @@ -17,12 +17,12 @@ using executorch::runtime::Span; using executorch::runtime::TensorInfo; namespace example { -template -PromptProcessor::PromptProcessor( +PromptProcessor::PromptProcessor( DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& method_name, - Metadata metadata) + Metadata metadata, + std::unique_ptr method_meta) : decoder_runner_(decoder_runner), kv_manager_(kv_manager), method_name_(method_name), @@ -32,33 +32,41 @@ PromptProcessor::PromptProcessor( k_cache_out_.resize(metadata_.num_layers); v_cache_out_.resize(metadata_.num_layers); // Calculate I/O size + Result attention_mask = method_meta->input_tensor_meta(1); + Result logits = method_meta->output_tensor_meta(0); input_toks_.size = metadata_.ar_len * sizeof(int64_t); - if (is_bert()) + if (is_bert()) { input_pos_.size = 0; - else + } else { input_pos_.size = metadata_.ar_len * sizeof(int32_t); + } + attention_mask_.dtype = attention_mask->scalar_type(); + attention_mask_.size = metadata_.ar_len * metadata_.context_len * + attention_mask_.getElementSize(); switch (metadata_.cache_mode) { case CacheMode::StaticCahce: - attention_mask_.size = - metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); window_attention_mask_.size = 0; break; - case 
CacheMode::HybridCache: - attention_mask_.size = - metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); - window_attention_mask_.size = - metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); + case CacheMode::HybridCache: { + Result window_attention_mask = + method_meta->input_tensor_meta(2); + window_attention_mask_.dtype = window_attention_mask->scalar_type(); + window_attention_mask_.size = metadata_.ar_len * metadata_.context_len * + window_attention_mask_.getElementSize(); break; + } default: ET_CHECK_MSG(false, "Unsupported llama cache mode"); break; } - logits_.size = metadata_.ar_len * metadata_.vocab_size * sizeof(uint16_t); + logits_.dtype = logits->scalar_type(); + logits_.size = + metadata_.ar_len * metadata_.vocab_size * logits_.getElementSize(); }; -template -void PromptProcessor::init_io( + +void PromptProcessor::init_io( IMemAlloc* buffer_manager, Result method_meta) { size_t idx = 0; @@ -80,8 +88,7 @@ void PromptProcessor::init_io( // [I]: attention_mask Result attention_mask = method_meta->input_tensor_meta(idx++); - attention_mask_.data = reinterpret_cast( - buffer_manager->allocate(attention_mask_.size)); + attention_mask_.data = buffer_manager->allocate(attention_mask_.size); attention_mask_.tensor = std::make_unique( attention_mask->scalar_type(), attention_mask->sizes().size(), @@ -97,8 +104,8 @@ void PromptProcessor::init_io( if (metadata_.cache_mode == CacheMode::HybridCache) { Result window_attention_mask = method_meta->input_tensor_meta(idx++); - window_attention_mask_.data = reinterpret_cast( - buffer_manager->allocate(window_attention_mask_.size)); + window_attention_mask_.data = + buffer_manager->allocate(window_attention_mask_.size); window_attention_mask_.tensor = std::make_unique( window_attention_mask->scalar_type(), window_attention_mask->sizes().size(), @@ -136,33 +143,30 @@ void PromptProcessor::init_io( for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>& cache = (cache_group == 0 ? 
k_cache_in_ : v_cache_in_); - std::vector> cache_ptrs = (cache_group == 0) + std::vector cache_ptrs = (cache_group == 0) ? kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { Result kv_cache = method_meta->input_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer].buffer; - cache[layer] = std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), const_cast(kv_cache->sizes().data()), - cache_ptr, + cache_ptrs[layer].buffer, const_cast( kv_cache->dim_order().data())); input_tensors_.emplace_back(cache[layer].get()); cache_inputs_.emplace_back(input_tensors_.back()); buffer_manager->add_memory_info( - cache_ptr, cache[layer]->nbytes(), kv_cache.get()); + cache_ptrs[layer].buffer, cache[layer]->nbytes(), kv_cache.get()); } } } // [O]: logits Result logits = method_meta->output_tensor_meta(0); - logits_.data = - reinterpret_cast(buffer_manager->allocate(logits_.size)); + logits_.data = buffer_manager->allocate(logits_.size); logits_.tensor = std::make_unique( logits->scalar_type(), logits->sizes().size(), @@ -177,21 +181,22 @@ void PromptProcessor::init_io( for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>& cache = (cache_group == 0 ? k_cache_out_ : v_cache_out_); - std::vector> cache_ptrs = (cache_group == 0) + std::vector cache_ptrs = (cache_group == 0) ? 
kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { Result kv_cache = method_meta->output_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer].output_buffer; cache[layer] = std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), const_cast(kv_cache->sizes().data()), - cache_ptr, + cache_ptrs[layer].output_buffer, const_cast(kv_cache->dim_order().data())); output_tensors_.emplace_back(cache[layer].get()); buffer_manager->add_memory_info( - cache_ptr, cache[layer]->nbytes(), kv_cache.get()); + cache_ptrs[layer].output_buffer, + cache[layer]->nbytes(), + kv_cache.get()); } } // Prepare the vector of EValue to run inference @@ -201,13 +206,11 @@ void PromptProcessor::init_io( } } -template -const std::vector& PromptProcessor::get_all_logits() { +const std::vector& PromptProcessor::get_all_logits() { return prompt_all_logits_; } -template -void PromptProcessor::prepare_io( +void PromptProcessor::prepare_io( const std::vector& prompt_tokens, int64_t prompt_pos, int64_t start_pos) { @@ -232,8 +235,7 @@ void PromptProcessor::prepare_io( } } -template -Result PromptProcessor::prefill( +Result PromptProcessor::prefill( std::vector prompt_tokens, int64_t start_pos, bool dump_logits, @@ -339,7 +341,9 @@ Result PromptProcessor::prefill( prompt_all_logits_.insert( prompt_all_logits_.end(), logits_.data, - logits_.data + metadata_.ar_len * metadata_.vocab_size); + logits_.data + + metadata_.ar_len * metadata_.vocab_size * + logits_.getElementSize()); } // In the last run, offset to the meaningful logits. 
if (i == num_iters - 1) { @@ -369,8 +373,4 @@ Result PromptProcessor::prefill( return cur_token; } -// Explicit instantiations -template class PromptProcessor; -template class PromptProcessor; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h index 599f7050d83..5317a8a77e1 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h +++ b/examples/qualcomm/oss_scripts/llama/runner/prompt_processor.h @@ -21,7 +21,7 @@ namespace example { * @class PromptProcessor * @brief Class for processing prompts using decoder and key-value manager. */ -template + class PromptProcessor { public: struct Metadata { @@ -36,9 +36,10 @@ class PromptProcessor { }; PromptProcessor( DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& method_name, - Metadata metadata); + Metadata metadata, + std::unique_ptr method_meta); virtual ~PromptProcessor() = default; @@ -55,9 +56,9 @@ class PromptProcessor { /** * @brief Get the all logits generated * - * @return std::vector& all the logits generated + * @return std::vector& all the logits generated */ - virtual const std::vector& get_all_logits(); + virtual const std::vector& get_all_logits(); /** * Prefill an LLM Module with the given text input. @@ -79,13 +80,8 @@ class PromptProcessor { * @return Total I/O size in bytes. 
*/ inline const size_t total_prompt_processor_io_size_in_bytes() const { - if (metadata_.cache_mode == CacheMode::HybridCache) { - return input_toks_.size + input_pos_.size + attention_mask_.size + - window_attention_mask_.size + logits_.size; - } else { - return input_toks_.size + input_pos_.size + attention_mask_.size + - logits_.size; - } + return input_toks_.size + input_pos_.size + attention_mask_.size + + window_attention_mask_.size + logits_.size; } protected: @@ -105,7 +101,7 @@ class PromptProcessor { int64_t prompt_pos, int64_t start_pos); DecoderRunner* decoder_runner_; - KVManager* kv_manager_; + KVManager* kv_manager_; std::string method_name_; // metadata @@ -114,9 +110,9 @@ class PromptProcessor { // inputs and outputs TensorStruct input_toks_; TensorStruct input_pos_; - TensorStruct attention_mask_; - TensorStruct window_attention_mask_; - TensorStruct logits_; + TensorStructRaw attention_mask_; + TensorStructRaw window_attention_mask_; + TensorStructRaw logits_; // layer -> TensorImpl std::vector> k_cache_in_; @@ -131,6 +127,6 @@ class PromptProcessor { std::vector cache_inputs_; // Unused by default, only used when dump_logits_path is provided. 
- std::vector prompt_all_logits_; + std::vector prompt_all_logits_; }; } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp index 0a4a8b9abb5..7257e869dcc 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.cpp @@ -66,17 +66,17 @@ void print_performance_report( void save_logits( const std::string& dump_logits_path, - const std::vector& prefill_logits, - const std::vector& decode_logits) { + const std::vector& prefill_logits, + const std::vector& decode_logits) { std::ofstream outFile(dump_logits_path.c_str(), std::ios::binary); if (outFile.is_open()) { outFile.write( reinterpret_cast(prefill_logits.data()), - prefill_logits.size() * sizeof(uint16_t)); + prefill_logits.size()); outFile.write( reinterpret_cast(decode_logits.data()), - decode_logits.size() * sizeof(uint16_t)); + decode_logits.size()); outFile.close(); } else { ET_CHECK_MSG(false, "Error saving the dump logits file"); @@ -85,8 +85,7 @@ void save_logits( } // namespace -template -Runner::Runner( +Runner::Runner( std::unique_ptr module, const std::string& decoder_model_version, const std::string& model_path, @@ -152,14 +151,12 @@ Runner::Runner( ET_LOG(Info, "eval mode=%d", eval_mode_); } -template -bool Runner::is_loaded() const { +bool Runner::is_loaded() const { return module_->is_loaded() && tokenizer_ && decoder_runner_ && prompt_processor_ && token_generator_ && kv_manager_ && buffer_manager_; } -template -Error Runner::load() { +Error Runner::load() { if (is_loaded()) { return Error::Ok; } @@ -275,13 +272,16 @@ Error Runner::load() { if (module_->method_names()->count("get_sliding_window") > 0) { sliding_window = ET_UNWRAP(module_->get("get_sliding_window")).toInt(); } - kv_manager_ = std::make_unique>(typename KVManager::Metadata{ - context_len_, - head_dim, - max_ar_len, - max_cache_len, - num_heads, - num_layers}); + kv_manager_ = 
std::make_unique( + KVManager::Metadata{ + context_len_, + head_dim, + max_ar_len, + max_cache_len, + num_heads, + num_layers}, + std::make_unique( + std::move(module_->method_meta(token_generator_method_name).get()))); if (attention_sink_rope_module_ != nullptr) { attention_sink_rope_runner_ = std::make_unique( @@ -290,11 +290,11 @@ Error Runner::load() { attention_sink_rope_runner_->load(method_names)); } - prompt_processor_ = std::make_unique>( + prompt_processor_ = std::make_unique( decoder_runner_.get(), kv_manager_.get(), prompt_processor_method_name, - typename PromptProcessor::Metadata{ + PromptProcessor::Metadata{ context_len_, num_heads, num_layers, @@ -302,15 +302,17 @@ Error Runner::load() { vocab_size, use_int64_token, sliding_window, - cache_mode_}); + cache_mode_}, + std::make_unique( + std::move(module_->method_meta(prompt_processor_method_name).get()))); if (eval_mode_ == EvalMode::kLookaheadDecoding) { - token_generator_ = std::make_unique>( + token_generator_ = std::make_unique( tokenizer_.get(), decoder_runner_.get(), kv_manager_.get(), token_generator_method_name, std::move(eos_ids), - typename LhdTokenGenerator::Metadata{ + LhdTokenGenerator::Metadata{ context_len_, num_heads, num_layers, @@ -322,15 +324,17 @@ Error Runner::load() { gcap_, sliding_window, cache_mode_}, - &stats_); + &stats_, + std::make_unique(std::move( + module_->method_meta(token_generator_method_name).get()))); } else { - token_generator_ = std::make_unique>( + token_generator_ = std::make_unique( tokenizer_.get(), decoder_runner_.get(), kv_manager_.get(), token_generator_method_name, std::move(eos_ids), - typename TokenGenerator::Metadata{ + TokenGenerator::Metadata{ context_len_, num_heads, num_layers, @@ -339,7 +343,9 @@ Error Runner::load() { use_int64_token, sliding_window, cache_mode_}, - &stats_); + &stats_, + std::make_unique(std::move( + module_->method_meta(token_generator_method_name).get()))); } buffer_manager_ = std::make_unique(); @@ -360,8 +366,7 @@ Error 
Runner::load() { return Error::Ok; } -template -Error Runner::generate( +Error Runner::generate( const std::string& prompt, const llm::GenerationConfig& config, std::function token_callback, @@ -370,8 +375,7 @@ Error Runner::generate( prompt, false, config, token_callback, stats_callback); } -template -Error Runner::generate_from_prompt_or_file( +Error Runner::generate_from_prompt_or_file( const std::string& prompt, bool tokenized_prompt, const llm::GenerationConfig& config, @@ -500,8 +504,7 @@ Error Runner::generate_from_prompt_or_file( return Error::Ok; } -template -Result Runner::get_decoder_model_version() { +Result Runner::get_decoder_model_version() { if (!is_loaded()) { stats_.model_load_start_ms = time_in_ms(); ET_CHECK_OK_OR_RETURN_ERROR(load()); @@ -510,8 +513,4 @@ Result Runner::get_decoder_model_version() { return decoder_model_version_; } -// Explicit instantiations -template class Runner; -template class Runner; - } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/runner.h b/examples/qualcomm/oss_scripts/llama/runner/runner.h index 39ce62c2d9f..5d03a12f61a 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/runner.h +++ b/examples/qualcomm/oss_scripts/llama/runner/runner.h @@ -46,12 +46,6 @@ enum DecoderModelVersion { kGemma2, }; -enum KvBitWidth { - kWidth8 = 8, - kWidth16 = 16, -}; - -template class Runner : public executorch::extension::llm::IRunner { public: explicit Runner( @@ -121,14 +115,15 @@ class Runner : public executorch::extension::llm::IRunner { DecoderModelVersion decoder_model_version_; std::unique_ptr buffer_manager_; - std::unique_ptr> kv_manager_; + std::unique_ptr kv_manager_; std::unique_ptr tokenizer_; std::unique_ptr decoder_runner_; std::unique_ptr attention_sink_rope_runner_; - std::unique_ptr> prompt_processor_; - std::unique_ptr> token_generator_; + std::unique_ptr prompt_processor_; + std::unique_ptr token_generator_; // stats executorch::llm::Stats stats_; }; + } // namespace example diff 
--git a/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp b/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp index 8ab82d932e1..098fcf9efa6 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp +++ b/examples/qualcomm/oss_scripts/llama/runner/token_generator.cpp @@ -17,15 +17,15 @@ using executorch::runtime::Span; using executorch::runtime::TensorInfo; namespace example { -template -TokenGenerator::TokenGenerator( +TokenGenerator::TokenGenerator( tokenizers::Tokenizer* tokenizer, DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& method_name, std::unique_ptr>&& eos_ids, Metadata metadata, - executorch::llm::Stats* stats) + executorch::llm::Stats* stats, + std::unique_ptr method_meta) : tokenizer_(tokenizer), decoder_runner_(decoder_runner), kv_manager_(kv_manager), @@ -39,32 +39,37 @@ TokenGenerator::TokenGenerator( v_cache_out_.resize(metadata_.num_layers); // Calculate I/O size + Result attention_mask = method_meta->input_tensor_meta(1); + Result logits = method_meta->output_tensor_meta(0); + input_toks_.size = metadata_.ar_len * sizeof(int64_t); input_pos_.size = metadata_.ar_len * sizeof(int32_t); - attention_mask_.size = - metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); + attention_mask_.dtype = attention_mask->scalar_type(); + attention_mask_.size = metadata_.ar_len * metadata_.context_len * + attention_mask_.getElementSize(); switch (metadata_.cache_mode) { case CacheMode::StaticCahce: - attention_mask_.size = - metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); window_attention_mask_.size = 0; break; - case CacheMode::HybridCache: - attention_mask_.size = - metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); - window_attention_mask_.size = - metadata_.ar_len * metadata_.context_len * sizeof(uint16_t); + case CacheMode::HybridCache: { + Result window_attention_mask = + method_meta->input_tensor_meta(2); + 
window_attention_mask_.dtype = window_attention_mask->scalar_type(); + window_attention_mask_.size = metadata_.ar_len * metadata_.context_len * + window_attention_mask_.getElementSize(); break; + } default: ET_CHECK_MSG(false, "Unsupported llama cache mode"); break; } - logits_.size = metadata_.ar_len * metadata_.vocab_size * sizeof(uint16_t); + logits_.dtype = logits->scalar_type(); + logits_.size = + metadata_.ar_len * metadata_.vocab_size * logits_.getElementSize(); } -template -void TokenGenerator::init_io( +void TokenGenerator::init_io( IMemAlloc* buffer_manager, Result method_meta) { size_t idx = 0; @@ -86,8 +91,7 @@ void TokenGenerator::init_io( // [I]: attention_mask Result attention_mask = method_meta->input_tensor_meta(idx++); - attention_mask_.data = reinterpret_cast( - buffer_manager->allocate(attention_mask_.size)); + attention_mask_.data = buffer_manager->allocate(attention_mask_.size); attention_mask_.tensor = std::make_unique( attention_mask->scalar_type(), attention_mask->sizes().size(), @@ -103,8 +107,8 @@ void TokenGenerator::init_io( if (metadata_.cache_mode == CacheMode::HybridCache) { Result window_attention_mask = method_meta->input_tensor_meta(idx++); - window_attention_mask_.data = reinterpret_cast( - buffer_manager->allocate(window_attention_mask_.size)); + window_attention_mask_.data = + buffer_manager->allocate(window_attention_mask_.size); window_attention_mask_.tensor = std::make_unique( window_attention_mask->scalar_type(), window_attention_mask->sizes().size(), @@ -141,31 +145,28 @@ void TokenGenerator::init_io( for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>& cache = (cache_group == 0 ? k_cache_in_ : v_cache_in_); - std::vector> cache_ptrs = (cache_group == 0) + std::vector cache_ptrs = (cache_group == 0) ? 
kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { Result kv_cache = method_meta->input_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer].buffer; - cache[layer] = std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), const_cast(kv_cache->sizes().data()), - cache_ptr, + cache_ptrs[layer].buffer, const_cast(kv_cache->dim_order().data())); input_tensors_.emplace_back(cache[layer].get()); cache_inputs_.emplace_back(input_tensors_.back()); buffer_manager->add_memory_info( - cache_ptr, cache[layer]->nbytes(), kv_cache.get()); + cache_ptrs[layer].buffer, cache[layer]->nbytes(), kv_cache.get()); } } // [O]: logits Result logits = method_meta->output_tensor_meta(0); - logits_.data = - reinterpret_cast(buffer_manager->allocate(logits_.size)); + logits_.data = buffer_manager->allocate(logits_.size); logits_.tensor = std::make_unique( logits->scalar_type(), logits->sizes().size(), @@ -180,21 +181,22 @@ void TokenGenerator::init_io( for (int cache_group = 0; cache_group < 2; ++cache_group) { std::vector>& cache = (cache_group == 0 ? k_cache_out_ : v_cache_out_); - std::vector> cache_ptrs = (cache_group == 0) + std::vector cache_ptrs = (cache_group == 0) ? 
kv_manager_->get_k_cache_() : kv_manager_->get_v_cache_(); for (int layer = 0; layer < metadata_.num_layers; ++layer, ++index) { Result kv_cache = method_meta->output_tensor_meta(index); - T* cache_ptr = cache_ptrs[layer].output_buffer; cache[layer] = std::make_unique( kv_cache->scalar_type(), kv_cache->sizes().size(), const_cast(kv_cache->sizes().data()), - cache_ptr, + cache_ptrs[layer].output_buffer, const_cast(kv_cache->dim_order().data())); output_tensors_.emplace_back(cache[layer].get()); buffer_manager->add_memory_info( - cache_ptr, cache[layer]->nbytes(), kv_cache.get()); + cache_ptrs[layer].output_buffer, + cache[layer]->nbytes(), + kv_cache.get()); } } // Prepare the vector of EValue to run inference @@ -204,14 +206,12 @@ void TokenGenerator::init_io( } } -template -const std::vector& TokenGenerator::get_all_logits() { +const std::vector& TokenGenerator::get_all_logits() { return token_all_logits_; } // This function only considers the case where token_generator_ar_len equals 1. -template -void TokenGenerator::prepare_io(uint64_t cur_token, int64_t start_pos) { +void TokenGenerator::prepare_io(uint64_t cur_token, int64_t start_pos) { // update input_tok *input_toks_.data = metadata_.use_int64_token ? 
cur_token : static_cast(cur_token); @@ -219,8 +219,7 @@ void TokenGenerator::prepare_io(uint64_t cur_token, int64_t start_pos) { *input_pos_.data = static_cast(start_pos); } -template -Result TokenGenerator::generate( +Result TokenGenerator::generate( std::vector tokens, int64_t start_pos, int32_t seq_len, @@ -306,7 +305,9 @@ Result TokenGenerator::generate( token_all_logits_.insert( token_all_logits_.end(), logits_.data, - logits_.data + metadata_.ar_len * metadata_.vocab_size); + logits_.data + + metadata_.ar_len * metadata_.vocab_size * + logits_.getElementSize()); } ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error()); executorch::aten::Tensor& logits_tensor = logits_res.get(); @@ -374,8 +375,5 @@ Result TokenGenerator::generate( return pos - start_pos; } -// Explicit instantiations -template class TokenGenerator; -template class TokenGenerator; } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/token_generator.h b/examples/qualcomm/oss_scripts/llama/runner/token_generator.h index 7f9264b1102..6945d907a76 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/token_generator.h +++ b/examples/qualcomm/oss_scripts/llama/runner/token_generator.h @@ -22,7 +22,7 @@ namespace example { * @class TokenGenerator * @brief Class for generating the token using decoder and key-value manager. 
*/ -template + class TokenGenerator { public: struct Metadata { @@ -38,11 +38,12 @@ class TokenGenerator { TokenGenerator( tokenizers::Tokenizer* tokenizer, DecoderRunner* decoder_runner, - KVManager* kv_manager, + KVManager* kv_manager, const std::string& method_name, std::unique_ptr>&& eos_ids, Metadata metadata, - executorch::llm::Stats* stats); + executorch::llm::Stats* stats, + std::unique_ptr method_meta); virtual ~TokenGenerator() = default; /** @@ -58,9 +59,9 @@ class TokenGenerator { /** * @brief Get the all logits generated * - * @return std::vector& all the logits generated + * @return std::vector& all the logits generated */ - virtual const std::vector& get_all_logits(); + virtual const std::vector& get_all_logits(); /**    * @brief Generate tokens. @@ -78,28 +79,23 @@ class TokenGenerator { bool dump_logits, AttentionSinkRopeRunner* attention_sink_rope_runner); inline const size_t total_token_generator_io_size_in_bytes() const { - if (metadata_.cache_mode == CacheMode::HybridCache) { - return input_toks_.size + input_pos_.size + attention_mask_.size + - window_attention_mask_.size + logits_.size; - } else { - return input_toks_.size + input_pos_.size + attention_mask_.size + - logits_.size; - } + return input_toks_.size + input_pos_.size + attention_mask_.size + + window_attention_mask_.size + logits_.size; } protected: tokenizers::Tokenizer* tokenizer_; DecoderRunner* decoder_runner_; - KVManager* kv_manager_; + KVManager* kv_manager_; std::string method_name_; std::unique_ptr> eos_ids_; // inputs and outputs TensorStruct input_toks_; TensorStruct input_pos_; - TensorStruct attention_mask_; - TensorStruct window_attention_mask_; - TensorStruct logits_; + TensorStructRaw attention_mask_; + TensorStructRaw window_attention_mask_; + TensorStructRaw logits_; // layer -> TensorImpl std::vector> k_cache_in_; @@ -128,6 +124,6 @@ class TokenGenerator { Metadata metadata_; // Unused by default, only used when dump_logits_path is provided. 
- std::vector token_all_logits_; + std::vector token_all_logits_; }; } // namespace example diff --git a/examples/qualcomm/oss_scripts/llama/runner/utils.h b/examples/qualcomm/oss_scripts/llama/runner/utils.h index bef6b1a2017..df6dddfdc6e 100644 --- a/examples/qualcomm/oss_scripts/llama/runner/utils.h +++ b/examples/qualcomm/oss_scripts/llama/runner/utils.h @@ -8,10 +8,16 @@ #pragma once #include +#include #include #include // Template struct to hold tensor data and tensor + +// TODO: Refactor these struct to use TensorPtr +// see https://docs.pytorch.org/executorch/stable/extension-tensor.html + +// TensorStruct whose dtype known in compile time template struct TensorStruct { std::unique_ptr tensor; @@ -20,3 +26,38 @@ struct TensorStruct { // data size in bytes size_t size; }; + +inline size_t getDtypeSize(executorch::aten::ScalarType dtype) { + switch (dtype) { + case executorch::aten::ScalarType::Float: + return sizeof(float); + case executorch::aten::ScalarType::Double: + return sizeof(double); + case executorch::aten::ScalarType::Int: + return sizeof(int32_t); + case executorch::aten::ScalarType::Long: + return sizeof(int64_t); + case executorch::aten::ScalarType::Byte: + return sizeof(uint8_t); + case executorch::aten::ScalarType::UInt16: + return sizeof(uint16_t); + default: + ET_CHECK_MSG( + false, + "Unsupported scalar type %s", + executorch::runtime::toString(dtype)); + break; + } +} + +// TensorStruct whose dtype known in runtime, and raw file is used +struct TensorStructRaw { + std::unique_ptr tensor; + std::byte* data; + // data size in bytes + size_t size; + executorch::aten::ScalarType dtype; + size_t getElementSize() const { + return getDtypeSize(dtype); + } +}; diff --git a/examples/qualcomm/oss_scripts/llama/wrappers/attention_sink_wrappers.py b/examples/qualcomm/oss_scripts/llama/wrappers/attention_sink_wrappers.py index 48386f181d8..de857dfc17c 100644 --- a/examples/qualcomm/oss_scripts/llama/wrappers/attention_sink_wrappers.py +++ 
b/examples/qualcomm/oss_scripts/llama/wrappers/attention_sink_wrappers.py @@ -13,6 +13,7 @@ import torch from executorch.backends.qualcomm._passes import TagQuantIO +from executorch.backends.qualcomm._passes.build_quant_io import BuildQuantIo from executorch.backends.qualcomm._passes.qnn_pass_manager import ( get_capture_program_passes, ) @@ -460,6 +461,7 @@ def compile(self, attention_sink_evictor_pte_path: str): alloc_graph_input=False, alloc_graph_output=False, ), + passes=[BuildQuantIo()], extract_delegate_segments=True, ) exec_prog_mgr = edge_prog_mgr.to_executorch(executorch_config) diff --git a/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py b/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py index ef72e0765fd..c7a831824ce 100644 --- a/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py +++ b/examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py @@ -19,6 +19,7 @@ import torch from executorch.backends.qualcomm._passes import FoldQDQ, I64toI32, TagQuantIO +from executorch.backends.qualcomm._passes.build_quant_io import BuildQuantIo from executorch.backends.qualcomm._passes.qnn_pass_manager import ( get_capture_program_passes, ) @@ -269,10 +270,18 @@ def permute(w, heads, partial_rotary_dim): QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY ]["skip_node"] = {"tokens"} - if tok_embedding is not None: - tok_embedding = tok_embedding.eval() + with torch.no_grad(): + if self.apply_embedding: + tok_embedding = torch.export.export( + tok_embedding.eval(), + tok_embedding.get_example_input(), + strict=True, + ).module() - return tok_embedding, decoder.eval() + decoder = torch.export.export( + decoder.eval(), self.export_input, strict=True + ).module() + return tok_embedding, decoder def _get_model_instance(self) -> LlamaModel: if self.mode == Mode.PREFILL and self.control_args.model_mode == "kv": @@ -607,23 +616,28 @@ def quantize(self, request: Request): # noqa: C901 ): return + data = request.method_data[TEXT_DECODER] # check bit width 
graph io fixed_point_type = {"kv_type": torch.float32, "io_type": torch.float32} - if self.quant_recipe.get_kv_io_bit_width() == 8: - fixed_point_type["kv_type"] = torch.uint8 - elif self.quant_recipe.get_kv_io_bit_width() == 16: - fixed_point_type["kv_type"] = torch.uint16 + if data.skip_quantize: + # already init as float32 + return else: - raise RuntimeError( - f"unknown kv io bit width {self.quant_recipe.get_kv_io_bit_width()}" - ) + if self.quant_recipe.get_kv_io_bit_width() == 8: + fixed_point_type["kv_type"] = torch.uint8 + elif self.quant_recipe.get_kv_io_bit_width() == 16: + fixed_point_type["kv_type"] = torch.uint16 + else: + raise RuntimeError( + f"unknown kv io bit width {self.quant_recipe.get_kv_io_bit_width()}" + ) - if self.quant_recipe.get_logits_output_bit_width() == 16: - fixed_point_type["io_type"] = torch.uint16 - else: - raise RuntimeError( - f"unknown logits io bit width {self.quant_recipe.get_logits_output_bit_width()}" - ) + if self.quant_recipe.get_logits_output_bit_width() == 16: + fixed_point_type["io_type"] = torch.uint16 + else: + raise RuntimeError( + f"unknown logits io bit width {self.quant_recipe.get_logits_output_bit_width()}" + ) data = request.method_data[TEXT_DECODER] audio_turns = request.method_data[ @@ -654,18 +668,6 @@ def quantize(self, request: Request): # noqa: C901 ) with torch.no_grad(): - # prepare tok embedding model for ptq - if self.apply_embedding: - self.tok_embedding = torch.export.export( - self.tok_embedding, - self.tok_embedding.get_example_input(), - strict=True, - ).module() - - # prepare decoder model for ptq - self.decoder = torch.export.export( - self.decoder, self.export_input, strict=True - ).module() if self.control_args.quant_recipe_suggestion: graph_module = copy.deepcopy(self.decoder) @@ -973,6 +975,7 @@ def compile(self, request: Request): # noqa: C901 alloc_graph_input=False, alloc_graph_output=False, ), + passes=[BuildQuantIo()], ) tok_embedding_exec_prog_mgr = 
tok_embedding_edge_prog_mgr.to_executorch( executorch_config @@ -1009,6 +1012,7 @@ def compile(self, request: Request): # noqa: C901 alloc_graph_input=False, alloc_graph_output=False, ), + passes=[BuildQuantIo()], ) exec_prog_mgr = edge_prog_mgr.to_executorch(executorch_config) data = request.method_data[TEXT_DECODER] @@ -1127,7 +1131,9 @@ def compile(self, request: Request): if self.control_args.verbose: print_delegation_info(edge_prog_mgr.exported_program().graph_module) - exec_prog_mgr = edge_prog_mgr.to_executorch(ExecutorchBackendConfig()) + exec_prog_mgr = edge_prog_mgr.to_executorch( + ExecutorchBackendConfig(passes=[BuildQuantIo()]) + ) data = request.method_data[self.modality] with open( f"{self.control_args.artifact}/{data.pte_filename}.pte", "wb" diff --git a/exir/passes/spec_prop_pass.py b/exir/passes/spec_prop_pass.py index 9adbf65dd90..8662b7e5bee 100644 --- a/exir/passes/spec_prop_pass.py +++ b/exir/passes/spec_prop_pass.py @@ -9,8 +9,12 @@ import operator from typing import Optional +# register llama.fallback +import executorch.extension.llm.custom_ops.op_fallback # noqa: F401 + import torch from executorch.exir.delegate import executorch_call_delegate +from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, ProxyValue from executorch.exir.tensor import TensorSpec from torch.export.exported_program import ExportGraphSignature @@ -75,9 +79,9 @@ def get_spec(x): elif node.op == "call_function" and node.target == operator.getitem: value_spec = pytree.tree_map(get_spec, node.args[0]) node.meta["spec"] = value_spec[node.args[1]] - elif ( - node.op == "call_function" - and node.target == executorch_call_delegate + elif node.op == "call_function" and node.target in ( + executorch_call_delegate, + exir_ops.edge.llama.fallback.default, ): # Note: We currently rely on delegate node specs not being regenerated, # as the spec is set somewhat manually when adding the call delegate node. 
diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index e072694f913..b9215f978bc 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -206,41 +206,14 @@ class ExecuTorchLlmJni : public facebook::jni::HybridClass { data_files_vector, cpp_load_mode); std::string decoder_model = "llama3"; // use llama3 for now - // Using 8bit as default since this meta is introduced with 16bit kv io - // support and older models only have 8bit kv io. - example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8; - if (module->method_names()->count("get_kv_io_bit_width") > 0) { - kv_bitwidth = static_cast( - module->get("get_kv_io_bit_width") - .get() - .toScalar() - .to()); - } - - if (kv_bitwidth == example::KvBitWidth::kWidth8) { - runner_ = std::make_unique>( - std::move(module), - decoder_model.c_str(), - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str(), - "", - "", - temperature_); - } else if (kv_bitwidth == example::KvBitWidth::kWidth16) { - runner_ = std::make_unique>( - std::move(module), - decoder_model.c_str(), - model_path->toStdString().c_str(), - tokenizer_path->toStdString().c_str(), - "", - "", - temperature_); - } else { - ET_CHECK_MSG( - false, - "Unsupported kv bitwidth: %ld", - static_cast(kv_bitwidth)); - } + runner_ = std::make_unique( + std::move(module), + decoder_model.c_str(), + model_path->toStdString().c_str(), + tokenizer_path->toStdString().c_str(), + "", + "", + temperature_); model_type_category_ = MODEL_TYPE_CATEGORY_LLM; #endif #if defined(EXECUTORCH_BUILD_MEDIATEK) diff --git a/extension/llm/custom_ops/model_sharding.py b/extension/llm/custom_ops/model_sharding.py index 6838b0958a2..916b13a90b8 100644 --- a/extension/llm/custom_ops/model_sharding.py +++ b/extension/llm/custom_ops/model_sharding.py @@ -7,8 +7,9 @@ import re from typing import List -import torch +import executorch.extension.llm.custom_ops.op_fallback # 
noqa: F401 +import torch from executorch.backends.qualcomm.utils.constants import ( QCOM_PASS_ACTIVATE_KEY, QCOM_PASS_ARGS_KWARGS_DEFAULTS_KEY, @@ -17,27 +18,6 @@ from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult from torch.export.exported_program import ExportedProgram -from torch.library import impl, Library - - -fallback_op_lib = Library("llama", "DEF") -# registering an operator. -fallback_op_lib.define("fallback(Tensor input) -> Tensor") - - -@impl(fallback_op_lib, "fallback") -def fallback_impl(a: torch.Tensor) -> torch.Tensor: - return a - - -# registering the out variant. -fallback_op_lib.define("fallback.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)") - - -@impl(fallback_op_lib, "fallback.out") -def fallback_out_impl(a: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor: - out.copy_(a) - return out class SplitGraph(ExportPass): diff --git a/extension/llm/custom_ops/op_fallback.py b/extension/llm/custom_ops/op_fallback.py new file mode 100644 index 00000000000..e94c81db51a --- /dev/null +++ b/extension/llm/custom_ops/op_fallback.py @@ -0,0 +1,29 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# pyre-ignore-all-errors + +import torch + +from torch.library import impl, Library + +fallback_op_lib = Library("llama", "DEF") +# registering an operator. +fallback_op_lib.define("fallback(Tensor input) -> Tensor") + + +@impl(fallback_op_lib, "fallback") +def fallback_impl(a: torch.Tensor) -> torch.Tensor: + return a + + +# registering the out variant. +fallback_op_lib.define("fallback.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)") + + +@impl(fallback_op_lib, "fallback.out") +def fallback_out_impl(a: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor: + out.copy_(a) + return out