Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 26 additions & 22 deletions backends/qualcomm/_passes/build_quant_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
# LICENSE file in the root directory of this source tree.
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_QUANTIZED_IO
from executorch.exir.delegate import executorch_call_delegate

from executorch.exir.pass_base import ExportPass, ProxyValue
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.tensor import TensorSpec
from torch.utils import _pytree as pytree


class BuildQuantIo(ExportPass):
Expand All @@ -28,22 +27,27 @@ def _make_spec(self, x):
else:
return None

def placeholder(self, name: str, arg, meta):
    """Cast a quantized-IO placeholder to its tagged dtype, then defer to the parent.

    If ``meta.data`` holds a truthy ``QCOM_QUANTIZED_IO`` entry, ``arg`` is
    converted to that dtype and ``meta["spec"]`` is rebuilt from the converted
    value. In every case the parent ``placeholder`` hook produces the result.
    """
    target_dtype = meta.data.get(QCOM_QUANTIZED_IO, None)
    if not target_dtype:
        # No quantized-IO tag: pass the argument through untouched.
        return super().placeholder(name, arg, meta)

    # NOTE(review): indentation is lost in this view; the spec refresh is
    # assumed to sit under the dtype guard together with the cast - confirm.
    converted = arg.to(dtype=target_dtype)
    meta["spec"] = self._make_spec(converted)
    return super().placeholder(name, converted, meta)

def call_getitem(self, value, key: int, meta):
    """Copy the ``key``-th spec of the indexed node into ``meta`` and defer to the parent.

    ``value.node.meta["spec"]`` is indexed with ``key`` and the selected entry
    becomes ``meta["spec"]`` before the parent ``call_getitem`` hook runs.
    """
    source_specs = value.node.meta["spec"]
    meta["spec"] = source_specs[key]
    return super().call_getitem(value, key, meta)

def call_delegate(self, lowered_module, args, kwargs, meta):
    """Rebuild ``meta["spec"]`` from the delegate's outputs, then defer to the parent.

    Every ``ProxyValue`` found in ``(args, kwargs)`` is replaced by its
    ``.data``; the delegate is invoked with the unwrapped positional arguments
    and ``_make_spec`` is mapped over whatever it returns.
    """

    def _unwrap(proxy):
        return proxy.data

    # Only the positional arguments are forwarded to the delegate call.
    args_data, _ = pytree.tree_map_only(ProxyValue, _unwrap, (args, kwargs))
    delegate_outputs = executorch_call_delegate(lowered_module, *args_data)
    meta["spec"] = pytree.tree_map(self._make_spec, delegate_outputs)
    return super().call_delegate(lowered_module, args, kwargs, meta)
def _build(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Refresh ``meta["spec"]`` on quantized-IO nodes and on delegate calls.

    Nodes tagged with ``QCOM_QUANTIZED_IO`` get their ``meta["val"]`` cast to
    the tagged dtype and their ``meta["spec"]`` rebuilt from it. Afterwards
    each ``executorch_call_delegate`` node receives a tuple of specs built
    from the ``meta["val"]`` of its users.

    Args:
        graph_module: Graph whose node metadata is updated in place.

    Returns:
        The same ``graph_module`` that was passed in.
    """
    # Force-update each delegate node's meta["spec"] so the runtime sees the
    # correct output tensor size.
    call_delegates = [
        node
        for node in graph_module.graph.nodes
        if node.op == "call_function" and node.target == executorch_call_delegate
    ]

    # Cast tagged values first so the delegate specs below see the final dtypes.
    for node in graph_module.graph.nodes:
        if QCOM_QUANTIZED_IO in node.meta:
            node.meta["val"] = node.meta["val"].to(dtype=node.meta[QCOM_QUANTIZED_IO])
            node.meta["spec"] = self._make_spec(node.meta["val"])

    for call_delegate in call_delegates:
        # NOTE(review): spec order follows the iteration order of `users`;
        # presumably that matches the delegate's output index order - confirm.
        call_delegate.meta["spec"] = tuple(
            self._make_spec(user.meta["val"]) for user in call_delegate.users
        )

    # The signature promises a GraphModule; previously this fell through and
    # returned None.
    return graph_module

def call(self, graph_module: torch.fx.GraphModule):
    """Run the spec rebuild on ``graph_module`` and report it as modified.

    After ``_build`` has updated the node metadata, dead code is eliminated
    from the graph and the module is recompiled. The returned ``PassResult``
    always carries ``True`` as its second argument.
    """
    self._build(graph_module)

    graph = graph_module.graph
    graph.eliminate_dead_code()
    graph_module.recompile()

    result = PassResult(graph_module, True)
    return result
15 changes: 14 additions & 1 deletion backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7475,6 +7475,8 @@ def test_llama_stories_110m(self):
"--max_context_len",
"128",
]
if self.use_fp16:
cmds.append("--use_fp16")
self.add_default_cmds(cmds)

golden_start_with = "Once upon a time,"
Expand All @@ -7495,7 +7497,10 @@ def test_llama_stories_110m(self):
# x86 does not allow weight sharing, so we don't check pte size
if not self.enable_x86_64:
pte_size = msg["pte_size"]
self.assertLessEqual(pte_size, 135_000_000) # 135MB
if self.use_fp16:
self.assertLessEqual(pte_size, 275_000_000) # 275MB
else:
self.assertLessEqual(pte_size, 135_000_000) # 135MB
if not self.compile_only and not self.enable_x86_64:
self.assertGreaterEqual(msg["inference_speed"], 220) # Lanai

Expand Down Expand Up @@ -9830,6 +9835,13 @@ def setup_environment():
choices=["wikitext_ppl", "hellaswag_acc_norm", "sqnr"],
type=str,
)
parser.add_argument(
"-F",
"--use_fp16",
help="If specified, will run in fp16 precision and discard ptq setting",
action="store_true",
default=False,
)

args, ns_args = parser.parse_known_args(namespace=unittest)
TestQNN.host = args.host
Expand Down Expand Up @@ -9858,6 +9870,7 @@ def setup_environment():
TestQNN.backend = args.backend
TestQNN.static_llm_eval_method = args.static_llm_eval_method
TestQNN.direct_build_folder = args.direct_build_folder
TestQNN.use_fp16 = args.use_fp16

return sys.argv[:1] + ns_args

Expand Down
1 change: 1 addition & 0 deletions backends/qualcomm/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ class TestQNN(unittest.TestCase):
inference_speed_output_path = "outputs/inference_speed.txt"
static_llm_eval_method = ""
direct_build_folder: str = ""
use_fp16 = False

@classmethod
def setUpClass(cls):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def _init_runner_base_cmd(self):
base_cmd = " ".join(
[
f"export LD_LIBRARY_PATH={self.qnn_sdk}/lib/x86_64-linux-clang/:{args.build_folder}/lib &&",
f"./{args.build_folder}/examples/qualcomm/oss_scripts/llama/{self.runner}",
f"{args.build_folder}/examples/qualcomm/oss_scripts/llama/{self.runner}",
f"--decoder_model_version {DECODER_MODEL_VERSION[args.decoder_model]}",
f"--tokenizer_path {self.runtime_tokenizer_path}",
f"--output_path {self.device_output_response_path}",
Expand Down
64 changes: 50 additions & 14 deletions examples/qualcomm/oss_scripts/llama/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
)

from executorch.backends.qualcomm.utils.utils import (
generate_gpu_compiler_spec,
generate_htp_compiler_spec,
generate_qnn_executorch_compiler_spec,
get_soc_to_chipset_map,
Expand Down Expand Up @@ -119,9 +120,15 @@ def compile(
# because the encoder is quite sensitive and quantization can make it harder for the model to distinguish
# between images within the same conversation.
to_skip = len(args.image_path) > 1
backend_options = generate_htp_compiler_spec(
use_fp16=to_skip,
)
if args.backend == "htp":
backend_options = generate_htp_compiler_spec(
use_fp16=to_skip,
)
elif args.backend == "gpu":
backend_options = generate_gpu_compiler_spec()
else:
raise ValueError(f"Unsupported backend {args.backend}")

encoder_compile_specs = generate_qnn_executorch_compiler_spec(
soc_model=get_soc_to_chipset_map()[args.soc_model],
backend_options=backend_options,
Expand All @@ -131,34 +138,48 @@ def compile(
skip_quantize[modality] = to_skip
compile_specs[modality] = encoder_compile_specs
elif is_multimodal and modality == TOK_EMBEDDING:
backend_options = generate_htp_compiler_spec(
use_fp16=False,
# x86 emulator does not support weight sharing
use_weight_sharing=not args.enable_x86_64,
)
if args.backend == "htp":
backend_options = generate_htp_compiler_spec(
use_fp16=False,
# x86 emulator does not support weight sharing
use_weight_sharing=not args.enable_x86_64,
)
elif args.backend == "gpu":
backend_options = generate_gpu_compiler_spec()
else:
raise ValueError(f"Unsupported backend {args.backend}")

compile_specs[modality] = [
generate_qnn_executorch_compiler_spec(
soc_model=get_soc_to_chipset_map()[args.soc_model],
backend_options=backend_options,
# x86 emulator does not support shared buffer
shared_buffer=not args.enable_x86_64,
online_prepare=args.online_prepare,
)
] * len(TOK_EMBEDDING_GRAPH_NAMES)
elif modality == TEXT_DECODER:
# compile spec for text decoder
backend_options = generate_htp_compiler_spec(
use_fp16=False,
use_multi_contexts=decoder_model_config.num_sharding > 1,
# x86 emulator does not support weight sharing
use_weight_sharing=not args.enable_x86_64,
)
if args.backend == "htp":
backend_options = generate_htp_compiler_spec(
use_fp16=args.use_fp16,
use_multi_contexts=decoder_model_config.num_sharding > 1,
# x86 emulator does not support weight sharing
use_weight_sharing=not args.enable_x86_64,
)
elif args.backend == "gpu":
backend_options = generate_gpu_compiler_spec()
else:
raise ValueError(f"Unsupported backend {args.backend}")
skip_quantize[modality] = args.use_fp16
compile_specs[modality] = [
generate_qnn_executorch_compiler_spec(
soc_model=get_soc_to_chipset_map()[args.soc_model],
backend_options=backend_options,
# x86 emulator does not support shared buffer
shared_buffer=not args.enable_x86_64,
use_mha2sha=True,
online_prepare=args.online_prepare,
)
] * len(DECODER_GRAPH_NAMES)

Expand Down Expand Up @@ -529,6 +550,14 @@ def _build_parser():
help="Number of examples in few-shot context",
)

parser.add_argument(
"-F",
"--use_fp16",
help="If specified, will run in fp16 precision and discard ptq setting",
action="store_true",
default=False,
)

parser.add_argument("-v", "--verbose", action="store_true")

parser.add_argument(
Expand Down Expand Up @@ -592,6 +621,12 @@ def export_llama(args) -> None:
pte_filename = "lookahead_llama_qnn"
else:
raise RuntimeError(f"Unknown model_mode: {args.model_mode}.")

if args.model_mode == "hybrid" and args.online_prepare:
raise RuntimeError(
"Currently hybrid mode is not compatible with online_prepare."
)

if args.decoder_model == "stories260k":
pte_filename = f"{args.decoder_model}_" + pte_filename
pte_filenames = {
Expand Down Expand Up @@ -740,6 +775,7 @@ def export_llama(args) -> None:
def main():
parser = _build_parser()
args = parser.parse_args()
args.build_folder = os.path.realpath(args.build_folder)
try:
export_llama(args)
except Exception as e:
Expand Down
25 changes: 3 additions & 22 deletions examples/qualcomm/oss_scripts/llama/qnn_llama_runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ std::string get_formatted_prompt(
return formatted_prompt;
}

template <typename T>
void start_runner(
std::unique_ptr<executorch::extension::Module> module,
std::vector<std::string>& prompts,
Expand All @@ -219,7 +218,7 @@ void start_runner(
gflags::GetCommandLineFlagInfoOrDie("tokenized_prompt").is_default ? false
: true;
// create llama runner
example::Runner<T> runner(
example::Runner runner(
std::move(module),
FLAGS_decoder_model_version.c_str(),
FLAGS_model_path.c_str(),
Expand Down Expand Up @@ -296,26 +295,8 @@ int main(int argc, char** argv) {
FLAGS_attention_sink_rope_path.c_str(),
executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors);
}
// Using 8bit as default since this meta is introduced with 16bit kv io
// support and older models only have 8bit kv io.
example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8;
if (module->method_names()->count("get_kv_io_bit_width") > 0) {
kv_bitwidth = static_cast<example::KvBitWidth>(
module->get("get_kv_io_bit_width").get().toScalar().to<int64_t>());
}

if (kv_bitwidth == example::KvBitWidth::kWidth8) {
start_runner<uint8_t>(
std::move(module), prompts, std::move(attention_sink_rope_module));
} else if (kv_bitwidth == example::KvBitWidth::kWidth16) {
start_runner<uint16_t>(
std::move(module), prompts, std::move(attention_sink_rope_module));
} else {
ET_CHECK_MSG(
false,
"Unsupported kv bitwidth: %ld",
static_cast<int64_t>(kv_bitwidth));
}
start_runner(
std::move(module), prompts, std::move(attention_sink_rope_module));

return 0;
}
38 changes: 7 additions & 31 deletions examples/qualcomm/oss_scripts/llama/qnn_multimodal_runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ std::vector<std::string> CollectPrompts(int argc, char** argv) {
return prompts;
}

template <typename T>
void start_multimodal_runner(
std::unique_ptr<executorch::extension::Module> encoder,
std::unique_ptr<executorch::extension::Module> tok_embedding,
Expand All @@ -150,7 +149,7 @@ void start_multimodal_runner(
: true;

// Create multimodal runner
example::QNNMultimodalRunner<T> runner(
example::QNNMultimodalRunner runner(
std::move(encoder),
std::move(tok_embedding),
std::move(text_decoder),
Expand Down Expand Up @@ -289,35 +288,12 @@ int main(int argc, char** argv) {
FLAGS_decoder_path.c_str(),
executorch::extension::Module::LoadMode::MmapUseMlockIgnoreErrors);

// Using 8bit as default since this meta is introduced with 16bit kv io
// support and older models only have 8bit kv io.
example::KvBitWidth kv_bitwidth = example::KvBitWidth::kWidth8;
if (text_decoder->method_names()->count("get_kv_io_bit_width") > 0) {
kv_bitwidth = static_cast<example::KvBitWidth>(
text_decoder->get("get_kv_io_bit_width")
.get()
.toScalar()
.to<int64_t>());
}
// Start runner with appropriate KV bitwidth
if (kv_bitwidth == example::KvBitWidth::kWidth8) {
start_multimodal_runner<uint8_t>(
std::move(encoder),
std::move(tok_embedding),
std::move(text_decoder),
prompts);
} else if (kv_bitwidth == example::KvBitWidth::kWidth16) {
start_multimodal_runner<uint16_t>(
std::move(encoder),
std::move(tok_embedding),
std::move(text_decoder),
prompts);
} else {
ET_CHECK_MSG(
false,
"Unsupported kv bitwidth: %ld",
static_cast<int64_t>(kv_bitwidth));
}
// Start runner
start_multimodal_runner(
std::move(encoder),
std::move(tok_embedding),
std::move(text_decoder),
prompts);

return 0;
}
28 changes: 23 additions & 5 deletions examples/qualcomm/oss_scripts/llama/runner/decoder_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#pragma once

#include <executorch/examples/qualcomm/oss_scripts/llama/runner/utils.h>
#include <executorch/extension/llm/sampler/sampler.h>
#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor.h>
Expand Down Expand Up @@ -56,19 +57,36 @@ class DecoderRunner {
/**
 * Sample the next token id from a logits tensor.
 *
 * The tensor is read as raw bytes and interpreted according to its scalar
 * type (UInt16, Byte or Float). When the tensor holds more than one token
 * (size(1) > 1), only the row at index `pos` is used.
 *
 * @param logits_tensor Logits; size(1) is read as the token count and
 *     size(2) as the vocabulary size.
 * @param pos Row index of the logits to sample when several tokens are present.
 * @return The token id chosen by `sampler_`.
 */
inline int32_t logits_to_token(
    const executorch::aten::Tensor& logits_tensor,
    int64_t pos) {
  std::byte* logits = logits_tensor.mutable_data_ptr<std::byte>();
  auto num_tokens = logits_tensor.size(1);
  auto vocab_size = logits_tensor.size(2);

  // Kept static so the buffer is reused between calls; resized on every call
  // so a tensor with a different vocab size cannot overrun a stale buffer.
  static std::vector<float> logits_f;
  logits_f.resize(vocab_size);

  executorch::aten::ScalarType logits_dtype = logits_tensor.scalar_type();
  size_t logits_nbytes = getDtypeSize(logits_dtype);

  // Offset (in bytes) to the meaningful logits row for the prefill model.
  std::byte* logits_last = logits;
  if (num_tokens > 1) {
    logits_last += pos * vocab_size * logits_nbytes;
  }

  // Dequantization is skipped: converting the raw element values to float
  // keeps their relative order, which is all the sampler needs here.
  auto widen_into_buffer = [&](const auto* src) {
    for (decltype(vocab_size) i = 0; i < vocab_size; ++i) {
      logits_f[i] = static_cast<float>(src[i]);
    }
  };

  // The dtype is loop-invariant, so dispatch once instead of per element.
  switch (logits_dtype) {
    case executorch::aten::ScalarType::UInt16:
      widen_into_buffer(reinterpret_cast<const uint16_t*>(logits_last));
      break;
    case executorch::aten::ScalarType::Byte:
      widen_into_buffer(reinterpret_cast<const uint8_t*>(logits_last));
      break;
    case executorch::aten::ScalarType::Float:
      widen_into_buffer(reinterpret_cast<const float*>(logits_last));
      break;
    default:
      ET_CHECK_MSG(
          false,
          "The scalar_type %s of logits is not supported",
          executorch::runtime::toString(logits_dtype));
  }
  return sampler_->sample(logits_f.data());
}
Expand Down
Loading
Loading