From 6b462fc614d8d466351678ef5a76ef535b642302 Mon Sep 17 00:00:00 2001 From: ssjia Date: Thu, 5 Feb 2026 10:21:13 -0800 Subject: [PATCH 1/3] [ET-VK][testing] Improve benchmark harness with reference caching and per-shader timing Pull Request resolved: https://github.com/pytorch/executorch/pull/17105 This change improves the benchmark test harness in three ways: 1. **Reference computation caching**: Test cases are now grouped by a `ReferenceKey` that captures the inputs affecting reference output (sizes, dtype, data generation type). Reference computation runs once per group and results are reused, significantly speeding up test suites with many storage/layout variations of the same logical test case. 2. **Per-shader timing breakdown**: Benchmark output now shows individual shader execution times with global and local workgroup sizes, making it easier to identify performance bottlenecks when multiple shaders participate in an operator. 3. **Deferred data generation**: Tensor data is now generated lazily with explicit seeding, enabling deterministic data sharing across grouped test cases. This ensures identical inputs produce identical reference outputs for caching correctness. Also adds string input support (`ValueSpec::make_string()`) and helper functions for concise test case naming (`layout_abbrev`, `repr_str`, `shape_string`). ghstack-source-id: 338638546 @exported-using-ghexport Differential Revision: [D91945038](https://our.internmc.facebook.com/intern/diff/D91945038/) --- backends/vulkan/test/custom_ops/utils.cpp | 649 +++++++++++++++++----- backends/vulkan/test/custom_ops/utils.h | 178 +++++- 2 files changed, 668 insertions(+), 159 deletions(-) diff --git a/backends/vulkan/test/custom_ops/utils.cpp b/backends/vulkan/test/custom_ops/utils.cpp index 307e7d562b9..b23c288a58f 100644 --- a/backends/vulkan/test/custom_ops/utils.cpp +++ b/backends/vulkan/test/custom_ops/utils.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include @@ -21,37 +22,55 @@ int get_seed() { return seed++; } +int get_seed_or_explicit(int explicit_seed) { + if (explicit_seed >= 0) { + return explicit_seed; + } + return get_seed(); +} + // Forward declarations for data generation utilities void generate_random_float_data( std::vector& data, float min_val = -1.0f, - float max_val = 1.0f); + float max_val = 1.0f, + int explicit_seed = -1); void generate_random_int_data( std::vector& data, int min_val = -10, - int max_val = 10); + int max_val = 10, + int explicit_seed = -1); void generate_randint_float_data( std::vector& data, int min_val = -10, - int max_val = 10); + int max_val = 10, + int explicit_seed = -1); void generate_randint_half_data( std::vector& data, int min_val = -10, - int max_val = 10); + int max_val = 10, + int explicit_seed = -1); void generate_random_int8_data( std::vector& data, int8_t min_val = -10, - int8_t max_val = 10); + int8_t max_val = 10, + int explicit_seed = -1); void generate_random_uint8_data( std::vector& data, uint8_t min_val = 0, - uint8_t max_val = 255); -void generate_random_2xint4_data(std::vector& data); -void generate_random_2xint4_data(std::vector& data); + uint8_t max_val = 255, + int explicit_seed = -1); +void generate_random_2xint4_data( + std::vector& data, + int explicit_seed = -1); +void generate_random_2xint4_data( + std::vector& data, + int explicit_seed = -1); void generate_random_int4_data( std::vector& data, int8_t min_val = -8, - int8_t max_val = 7); + int8_t max_val = 7, + int explicit_seed = -1); void generate_ones_data(std::vector& data); void generate_zeros_data(std::vector& data); @@ -96,7 +115,7 @@ void set_debugging(bool enable_debugging) { } // ValueSpec implementation -void ValueSpec::generate_tensor_data() { +void ValueSpec::generate_tensor_data(int seed) { if (spec_type != SpecType::Tensor) { return; } @@ -107,15 +126,15 @@ void ValueSpec::generate_tensor_data() { case vkapi::kFloat: { float_data.resize(num_elements); if (data_gen_type == DataGenType::RANDOM) { - generate_random_float_data(float_data); + generate_random_float_data(float_data, -1.0f, 1.0f, seed); } else if (data_gen_type == DataGenType::RANDOM_SCALES) { - generate_random_float_data(float_data, 0.005, 0.015); + generate_random_float_data(float_data, 0.005, 0.015, seed); } else if (data_gen_type == DataGenType::RANDINT) { - generate_randint_float_data(float_data); + generate_randint_float_data(float_data, -10, 10, seed); } else if (data_gen_type == DataGenType::RANDINT8) { - generate_randint_float_data(float_data, -128, 127); + generate_randint_float_data(float_data, -128, 127, seed); } else if (data_gen_type == DataGenType::RANDINT4) { - generate_randint_float_data(float_data, -8, 7); + generate_randint_float_data(float_data, -8, 7, seed); } else if (data_gen_type == DataGenType::ONES) { generate_ones_data(float_data); } else if (data_gen_type == DataGenType::ZEROS) { @@ -130,17 +149,17 @@ void ValueSpec::generate_tensor_data() { if (data_gen_type == DataGenType::RANDOM) { // Generate random float data first, then convert to half std::vector temp_data(num_elements); - generate_random_float_data(temp_data); + generate_random_float_data(temp_data, -1.0f, 1.0f, seed); for (size_t i = 0; i < temp_data.size(); ++i) { // Simple conversion to uint16_t representation of half half_data[i] = static_cast(temp_data[i] * 32767.0f); } } else if (data_gen_type == DataGenType::RANDINT) { - generate_randint_half_data(half_data); + generate_randint_half_data(half_data, -10, 10, seed); } else if (data_gen_type == DataGenType::RANDINT8) { - generate_randint_half_data(half_data, -128, 127); + generate_randint_half_data(half_data, -128, 127, seed); } else if (data_gen_type == DataGenType::RANDINT4) { - generate_randint_half_data(half_data, -8, 7); + generate_randint_half_data(half_data, -8, 7, seed); } else if (data_gen_type == DataGenType::ONES) { std::fill( half_data.begin(), @@ -162,14 +181,17 @@ void ValueSpec::generate_tensor_data() { case vkapi::kInt: { int32_data.resize(num_elements); if (data_gen_type == DataGenType::RANDOM) { - generate_random_int_data(int32_data); + generate_random_int_data(int32_data, -10, 10, seed); } else if (data_gen_type == DataGenType::RANDINT) { generate_random_int_data( - int32_data); // For int type, RANDINT is same as RANDOM + int32_data, + -10, + 10, + seed); // For int type, RANDINT is same as RANDOM } else if (data_gen_type == DataGenType::RANDINT8) { - generate_random_int_data(int32_data, -128, 127); + generate_random_int_data(int32_data, -128, 127, seed); } else if (data_gen_type == DataGenType::RANDINT4) { - generate_random_int_data(int32_data, -8, 7); + generate_random_int_data(int32_data, -8, 7, seed); } else if (data_gen_type == DataGenType::ONES) { std::fill(int32_data.begin(), int32_data.end(), 1); } else if (data_gen_type == DataGenType::ZEROS) { @@ -182,13 +204,13 @@ void ValueSpec::generate_tensor_data() { case vkapi::kChar: { int8_data.resize(num_elements); if (data_gen_type == DataGenType::RANDOM) { - generate_random_int8_data(int8_data); + generate_random_int8_data(int8_data, -10, 10, seed); } else if (data_gen_type == DataGenType::RANDINT) { - generate_random_int8_data(int8_data); + generate_random_int8_data(int8_data, -10, 10, seed); } else if (data_gen_type == DataGenType::RANDINT8) { - generate_random_int8_data(int8_data, -128, 127); + generate_random_int8_data(int8_data, -128, 127, seed); } else if (data_gen_type == DataGenType::RANDINT4) { - generate_random_2xint4_data(int8_data); + generate_random_2xint4_data(int8_data, seed); } else if (data_gen_type == DataGenType::ONES) { std::fill(int8_data.begin(), int8_data.end(), 1); } else if (data_gen_type == DataGenType::ONES_INT4) { @@ -204,13 +226,13 @@ void ValueSpec::generate_tensor_data() { case vkapi::kByte: { uint8_data.resize(num_elements); if (data_gen_type == DataGenType::RANDOM) { - generate_random_uint8_data(uint8_data); + generate_random_uint8_data(uint8_data, 0, 255, seed); } else if (data_gen_type == DataGenType::RANDINT) { - generate_random_uint8_data(uint8_data); + generate_random_uint8_data(uint8_data, 0, 255, seed); } else if (data_gen_type == DataGenType::RANDINT8) { - generate_random_uint8_data(uint8_data, 0, 255); + generate_random_uint8_data(uint8_data, 0, 255, seed); } else if (data_gen_type == DataGenType::RANDINT4) { - generate_random_2xint4_data(uint8_data); + generate_random_2xint4_data(uint8_data, seed); } else if (data_gen_type == DataGenType::ONES) { std::fill(uint8_data.begin(), uint8_data.end(), 1); } else if (data_gen_type == DataGenType::ONES_INT4) { @@ -227,9 +249,9 @@ void ValueSpec::generate_tensor_data() { // Default to float float_data.resize(num_elements); if (data_gen_type == DataGenType::RANDOM) { - generate_random_float_data(float_data); + generate_random_float_data(float_data, -1.0f, 1.0f, seed); } else if (data_gen_type == DataGenType::RANDINT) { - generate_randint_float_data(float_data); + generate_randint_float_data(float_data, -10, 10, seed); } else if (data_gen_type == DataGenType::ONES) { generate_ones_data(float_data); } else if (data_gen_type == DataGenType::ZEROS) { @@ -316,6 +338,11 @@ std::string ValueSpec::to_string() const { result += (data_gen_type == DataGenType::FIXED) ? "FIXED" : "RANDOM"; result += ")"; return result; + case SpecType::String: + result += "type=String, value=\""; + result += get_string_value(); + result += "\")"; + return result; } for (size_t i = 0; i < sizes.size(); ++i) { @@ -494,8 +521,9 @@ const void* ValueSpec::get_data_ptr() const { void generate_random_float_data( std::vector& data, float min_val, - float max_val) { - std::mt19937 gen(get_seed()); + float max_val, + int explicit_seed) { + std::mt19937 gen(get_seed_or_explicit(explicit_seed)); std::uniform_real_distribution dis(min_val, max_val); for (auto& val : data) { val = dis(gen); @@ -505,8 +533,9 @@ void generate_random_float_data( void generate_random_int_data( std::vector& data, int min_val, - int max_val) { - std::mt19937 gen(get_seed()); + int max_val, + int explicit_seed) { + std::mt19937 gen(get_seed_or_explicit(explicit_seed)); std::uniform_int_distribution dis(min_val, max_val); for (auto& val : data) { val = dis(gen); @@ -516,8 +545,9 @@ void generate_random_int_data( void generate_randint_float_data( std::vector& data, int min_val, - int max_val) { - std::mt19937 gen(get_seed()); + int max_val, + int explicit_seed) { + std::mt19937 gen(get_seed_or_explicit(explicit_seed)); std::uniform_int_distribution dis(min_val, max_val); for (auto& val : data) { val = static_cast(dis(gen)); @@ -527,8 +557,9 @@ void generate_randint_float_data( void generate_randint_half_data( std::vector& data, int min_val, - int max_val) { - std::mt19937 gen(get_seed()); + int max_val, + int explicit_seed) { + std::mt19937 gen(get_seed_or_explicit(explicit_seed)); std::uniform_int_distribution dis(min_val, max_val); for (auto& val : data) { val = static_cast(std::abs(dis(gen)) % 65536); @@ -542,8 +573,9 @@ void generate_ones_data(std::vector& data) { void generate_random_int8_data( std::vector& data, int8_t min_val, - int8_t max_val) { - std::mt19937 gen(get_seed()); + int8_t max_val, + int explicit_seed) { + std::mt19937 gen(get_seed_or_explicit(explicit_seed)); std::uniform_int_distribution dis(min_val, max_val); for (auto& val : data) { val = static_cast(dis(gen)); @@ -553,8 +585,9 @@ void generate_random_int8_data( void generate_random_uint8_data( std::vector& data, uint8_t min_val, - uint8_t max_val) { - std::mt19937 gen(get_seed()); + uint8_t max_val, + int explicit_seed) { + std::mt19937 gen(get_seed_or_explicit(explicit_seed)); std::uniform_int_distribution dis(min_val, max_val); for (auto& val : data) { val = static_cast(dis(gen)); @@ -564,16 +597,17 @@ void generate_random_uint8_data( void generate_random_int4_data( std::vector& data, int8_t min_val, - int8_t max_val) { - std::mt19937 gen(get_seed()); + int8_t max_val, + int explicit_seed) { + std::mt19937 gen(get_seed_or_explicit(explicit_seed)); std::uniform_int_distribution dis(min_val, max_val); for (auto& val : data) { val = static_cast(dis(gen)); } } -void generate_random_2xint4_data(std::vector& data) { - std::mt19937 gen(get_seed()); +void generate_random_2xint4_data(std::vector& data, int explicit_seed) { + std::mt19937 gen(get_seed_or_explicit(explicit_seed)); std::uniform_int_distribution dis(-8, 7); // Signed 4-bit range for (auto& val : data) { // Generate two separate 4-bit values @@ -584,8 +618,10 @@ void generate_random_2xint4_data(std::vector& data) { } } -void generate_random_2xint4_data(std::vector& data) { - std::mt19937 gen(get_seed()); +void generate_random_2xint4_data( + std::vector& data, + int explicit_seed) { + std::mt19937 gen(get_seed_or_explicit(explicit_seed)); std::uniform_int_distribution dis(0, 15); // Unsigned 4-bit range for (auto& val : data) { // Generate two separate 4-bit values @@ -652,6 +688,88 @@ bool ValueSpec::validate_against_reference( return true; } +// Ensure data is generated for this ValueSpec +void ValueSpec::ensure_data_generated(int seed) { + if (data_generated_) { + return; + } + generate_tensor_data(seed); + data_generated_ = true; +} + +// Copy input data from another ValueSpec +void ValueSpec::copy_data_from(const ValueSpec& other) { + if (!is_tensor() || !other.is_tensor()) { + return; + } + // Copy raw data based on dtype + float_data = other.float_data; + int32_data = other.int32_data; + half_data = other.half_data; + int8_data = other.int8_data; + uint8_data = other.uint8_data; + data_generated_ = other.data_generated_; +} + +// ReferenceKey implementation +ReferenceKey ReferenceKey::from_test_case(const TestCase& tc) { + std::ostringstream oss; + + // Serialize inputs that affect reference computation + // Skip: storage_type, memory_layout, string values (like impl_selector) + for (size_t i = 0; i < tc.inputs().size(); ++i) { + const ValueSpec& input = tc.inputs()[i]; + oss << "i" << i << ":"; + + if (input.is_tensor()) { + // For tensors: sizes, dtype, data_gen_type, is_constant + oss << "T["; + for (size_t j = 0; j < input.sizes.size(); ++j) { + if (j > 0) + oss << ","; + oss << input.sizes[j]; + } + oss << "]d" << static_cast(input.dtype); + oss << "g" << static_cast(input.data_gen_type); + oss << "c" << (input.is_constant() ? 1 : 0); + oss << "n" << (input.is_none() ? 1 : 0); + } else if (input.is_int()) { + oss << "I" << input.get_int_value(); + } else if (input.is_float()) { + oss << "F" << input.get_float_value(); + } else if (input.is_bool()) { + oss << "B" << (input.get_bool_value() ? 1 : 0); + } else if (input.is_int_list()) { + oss << "L["; + const auto& list = input.get_int_list(); + for (size_t j = 0; j < list.size(); ++j) { + if (j > 0) + oss << ","; + oss << list[j]; + } + oss << "]"; + } + // Skip string inputs (like impl_selector) as they don't affect reference + oss << ";"; + } + + // Also include output shapes for completeness + for (size_t i = 0; i < tc.outputs().size(); ++i) { + const ValueSpec& output = tc.outputs()[i]; + oss << "o" << i << ":["; + for (size_t j = 0; j < output.sizes.size(); ++j) { + if (j > 0) + oss << ","; + oss << output.sizes[j]; + } + oss << "]d" << static_cast(output.dtype) << ";"; + } + + ReferenceKey key; + key.key_string = oss.str(); + return key; +} + // Helper function to collect GPU timing from querypool float collect_gpu_timing_us( ComputeGraph& graph, @@ -685,11 +803,68 @@ float collect_gpu_timing_us( return 0.0f; } +// Helper function to collect per-shader GPU timing from querypool +// Returns a map of shader_name -> timing_us for non-filtered shaders +std::unordered_map collect_per_shader_timing_us( + ComputeGraph& graph, + const std::vector& shader_filter) { + std::unordered_map shader_timings; + + graph.context()->querypool().extract_results(); + const auto results = graph.context()->querypool().get_shader_timestamp_data(); + for (const auto& shader_result : results) { + bool filtered = false; + // Check if this shader matches any filter pattern + for (const auto& filter_pattern : shader_filter) { + if (shader_result.kernel_name.find(filter_pattern) != std::string::npos) { + filtered = true; + break; + } + } + + if (!filtered) { + // Calculate duration from start and end times, convert from ns to μs + uint64_t duration_ns = + shader_result.end_time_ns - shader_result.start_time_ns; + float duration_us = static_cast(duration_ns) / 1000.0f; + // Accumulate timing for shaders with the same name + shader_timings[shader_result.kernel_name] += duration_us; + } + } + return shader_timings; +} + // BenchmarkResult implementation void BenchmarkResult::add_iter_timing(float time_us) { iter_timings.push_back(time_us); } +void BenchmarkResult::add_shader_timing( + const std::string& shader_name, + float time_us, + const uint32_t global_wg[3], + const uint32_t local_wg[3]) { + // Find existing shader timing or create new one + for (auto& st : shader_timings_) { + if (st.shader_name == shader_name) { + st.iter_timings_us.push_back(time_us); + // Work group sizes should be consistent across iterations + return; + } + } + // Not found, create new entry + ShaderTiming new_timing; + new_timing.shader_name = shader_name; + new_timing.iter_timings_us.push_back(time_us); + new_timing.global_wg_size[0] = global_wg[0]; + new_timing.global_wg_size[1] = global_wg[1]; + new_timing.global_wg_size[2] = global_wg[2]; + new_timing.local_wg_size[0] = local_wg[0]; + new_timing.local_wg_size[1] = local_wg[1]; + new_timing.local_wg_size[2] = local_wg[2]; + shader_timings_.push_back(std::move(new_timing)); +} + float BenchmarkResult::get_avg_time_us() const { if (iter_timings.empty()) { return 0.0f; @@ -739,11 +914,27 @@ void BenchmarkResult::print_summary( const std::string& size_info, float total_gflops) const { static constexpr int OPERATOR_NAME_WIDTH = 50; - static constexpr int KERNEL_NAME_WIDTH = 70; + static constexpr int GLOBAL_WG_WIDTH = 16; + static constexpr int LOCAL_WG_WIDTH = 12; + static constexpr int KERNEL_NAME_WIDTH = 80; static constexpr int SIZE_INFO_WIDTH = 20; - static constexpr int TIMING_WIDTH = 20; - static constexpr int GFLOPS_WIDTH = 20; - static constexpr int CORRECTNESS_WIDTH = 10; + static constexpr int TIMING_WIDTH = 16; + static constexpr int GFLOPS_WIDTH = 14; + static constexpr int CORRECTNESS_WIDTH = 8; + + // Helper to truncate shader names longer than 46 chars to 44 chars + ".." + auto truncate_shader_name = [](const std::string& name) -> std::string { + if (name.length() > 46) { + return name.substr(0, 44) + ".."; + } + return name; + }; + + // Helper to format work group size as (x,y,z) + auto format_wg_size = [](const uint32_t wg[3]) -> std::string { + return "(" + std::to_string(wg[0]) + "," + std::to_string(wg[1]) + "," + + std::to_string(wg[2]) + ")"; + }; std::string correctness_str; switch (correctness_status_) { @@ -758,14 +949,74 @@ void BenchmarkResult::print_summary( break; } - std::cout << std::left << std::setw(OPERATOR_NAME_WIDTH) - << get_operator_name() << " " << std::left - << std::setw(KERNEL_NAME_WIDTH) << get_kernel_name() << std::right - << " " << std::setw(SIZE_INFO_WIDTH) << size_info - << std::setw(TIMING_WIDTH) << std::fixed << std::setprecision(3) - << get_avg_time_us() << " μs " << std::setw(GFLOPS_WIDTH) - << std::fixed << std::setprecision(3) << total_gflops << " GFLOP/s " - << std::setw(CORRECTNESS_WIDTH) << correctness_str << std::endl; + // If we have per-shader timing data, print one line per shader plus overall + if (!shader_timings_.empty()) { + // If only one shader, print a single combined row + if (shader_timings_.size() == 1) { + const auto& st = shader_timings_[0]; + std::cout << std::left << std::setw(OPERATOR_NAME_WIDTH) + << truncate_shader_name(st.shader_name) << " " << std::left + << std::setw(GLOBAL_WG_WIDTH) + << format_wg_size(st.global_wg_size) << std::left + << std::setw(LOCAL_WG_WIDTH) << format_wg_size(st.local_wg_size) + << std::left << std::setw(KERNEL_NAME_WIDTH) + << get_kernel_name() << std::right << " " + << std::setw(SIZE_INFO_WIDTH) << size_info + << std::setw(TIMING_WIDTH) << std::fixed << std::setprecision(3) + << get_avg_time_us() << " μs " << std::setw(GFLOPS_WIDTH) + << std::fixed << std::setprecision(3) << total_gflops + << " GFLOP/s " << std::setw(CORRECTNESS_WIDTH) + << correctness_str << std::endl; + } else { + // Multiple shaders: print individual shader lines (without GFLOP/s) + for (size_t i = 0; i < shader_timings_.size(); ++i) { + const auto& st = shader_timings_[i]; + float shader_avg_time = st.get_avg_time_us(); + + // Shader lines don't show test case info + std::cout << std::left << std::setw(OPERATOR_NAME_WIDTH) + << truncate_shader_name(st.shader_name) << " " << std::left + << std::setw(GLOBAL_WG_WIDTH) + << format_wg_size(st.global_wg_size) << std::left + << std::setw(LOCAL_WG_WIDTH) + << format_wg_size(st.local_wg_size) << std::left + << std::setw(KERNEL_NAME_WIDTH) << "" << std::right << " " + << std::setw(SIZE_INFO_WIDTH) << "" << std::setw(TIMING_WIDTH) + << std::fixed << std::setprecision(3) << shader_avg_time + << " μs " << std::setw(GFLOPS_WIDTH) << "" << " " + << std::setw(CORRECTNESS_WIDTH) << "" << std::endl; + } + + // Print overall row with operator name, test case info, total time, and + // GFLOP/s + std::cout << std::left << std::setw(OPERATOR_NAME_WIDTH) + << get_operator_name() << " " << std::left + << std::setw(GLOBAL_WG_WIDTH) << "" << std::left + << std::setw(LOCAL_WG_WIDTH) << "" << std::left + << std::setw(KERNEL_NAME_WIDTH) << get_kernel_name() + << std::right << " " << std::setw(SIZE_INFO_WIDTH) << size_info + << std::setw(TIMING_WIDTH) << std::fixed << std::setprecision(3) + << get_avg_time_us() << " μs " << std::setw(GFLOPS_WIDTH) + << std::fixed << std::setprecision(3) << total_gflops + << " GFLOP/s " << std::setw(CORRECTNESS_WIDTH) + << correctness_str << std::endl; + } + + // Print separator line between test cases + } else { + // No per-shader timing data, use the original format + std::cout << std::left << std::setw(OPERATOR_NAME_WIDTH) + << get_operator_name() << " " << std::left + << std::setw(GLOBAL_WG_WIDTH) << "" << std::left + << std::setw(LOCAL_WG_WIDTH) << "" << std::left + << std::setw(KERNEL_NAME_WIDTH) << get_kernel_name() << std::right + << " " << std::setw(SIZE_INFO_WIDTH) << size_info + << std::setw(TIMING_WIDTH) << std::fixed << std::setprecision(3) + << get_avg_time_us() << " μs " << std::setw(GFLOPS_WIDTH) + << std::fixed << std::setprecision(3) << total_gflops + << " GFLOP/s " << std::setw(CORRECTNESS_WIDTH) << correctness_str + << std::endl; + } } // TestResult implementation @@ -778,7 +1029,7 @@ void TestResult::add_result(BenchmarkResult&& result) { } void TestResult::print_summary() const { - static constexpr int CASE_WIDTH = 80; + static constexpr int CASE_WIDTH = 100; static constexpr int KERNEL_NAME_WIDTH = 20; static constexpr int TIMING_WIDTH = 12; static constexpr int PASS_WIDTH = 8; @@ -1069,6 +1320,10 @@ ComputeGraph setup_compute_graph(TestCase& test_case, std::string op_name) { } ValueRef input_value = graph.add_scalar_list(std::move(int64_list)); input_values.push_back(input_value); + } else if (input_spec.is_string()) { + std::string str_copy = input_spec.get_string_value(); + ValueRef input_value = graph.add_string(std::move(str_copy)); + input_values.push_back(input_value); } else if (input_spec.is_constant()) { ValueRef input_value = graph.add_tensorref( input_spec.get_tensor_sizes(), @@ -1200,9 +1455,38 @@ execute_test_case(TestCase& test_case, int warmup_runs, int benchmark_runs) { float cpu_time_us = static_cast(cpu_duration.count()); total_cpu_time_us += cpu_time_us; - // Collect GPU timing using helper function - float gpu_time_us = - collect_gpu_timing_us(graph, test_case.get_shader_filter()); + // Collect per-shader GPU timing - get raw shader results to preserve + // metadata + graph.context()->querypool().extract_results(); + const auto shader_results = + graph.context()->querypool().get_shader_timestamp_data(); + + // Calculate total GPU time from per-shader timings + float gpu_time_us = 0.0f; + for (const auto& shader_result : shader_results) { + // Check if this shader matches any filter pattern + bool filtered = false; + for (const auto& filter_pattern : test_case.get_shader_filter()) { + if (shader_result.kernel_name.find(filter_pattern) != + std::string::npos) { + filtered = true; + break; + } + } + + if (!filtered) { + uint64_t duration_ns = + shader_result.end_time_ns - shader_result.start_time_ns; + float duration_us = static_cast(duration_ns) / 1000.0f; + gpu_time_us += duration_us; + // Store per-shader timing with work group sizes + result.add_shader_timing( + shader_result.kernel_name, + duration_us, + shader_result.metadata.global_workgroup_size, + shader_result.metadata.local_workgroup_size); + } + } total_gpu_time_us += gpu_time_us; // Add the appropriate timing based on the flag @@ -1274,110 +1558,177 @@ TestResult execute_test_cases( << operation_name << std::endl; print_separator(); + // Group test cases by ReferenceKey for caching reference computations + // Use a vector to preserve the order in which groups first appear + std::vector group_order; + std::unordered_map, ReferenceKeyHash> + groups; + for (size_t i = 0; i < test_cases.size(); ++i) { + ReferenceKey key = ReferenceKey::from_test_case(test_cases[i]); + if (groups.find(key) == groups.end()) { + group_order.push_back(key); + } + groups[key].push_back(i); + } + bool any_correctness_failed = false; float total_gflops = 0.0f; + size_t test_case_counter = 0; + + // Process each group: generate data, compute reference, execute, and print + // Iterate in the order groups first appeared in test_cases + for (const auto& key : group_order) { + const auto& indices = groups[key]; + if (indices.empty()) + continue; + + // Get first test case as the "prototype" + size_t prototype_idx = indices[0]; + TestCase& prototype = test_cases[prototype_idx]; + + // Generate data for prototype with deterministic seed based on key + int group_seed = + static_cast(std::hash{}(key.key_string) % 10000); + for (auto& input : prototype.inputs()) { + input.ensure_data_generated(group_seed++); + } - for (size_t i = 0; i < test_cases.size(); ++i) { - TestCase& test_case = test_cases[i]; - - // Compute reference data if reference function is provided - bool skipped_reference_fn = true; + // Compute reference once for prototype + bool ref_computed = false; + std::vector> ref_data; if (reference_compute_func) { try { - reference_compute_func(test_case); - skipped_reference_fn = false; - } catch (const std::invalid_argument& e) { - if (debugging()) { - std::cout << "Compute reference skipped: " << e.what() << std::endl; + reference_compute_func(prototype); + ref_computed = true; + + // Cache the reference output for this group + for (const auto& output : prototype.outputs()) { + ref_data.push_back(output.get_ref_float_data()); } + } catch (const std::invalid_argument& _) { + // Reference computation skipped for this group } } - // Execute single test case - BenchmarkResult result; - bool shader_not_supported = false; - try { - result = execute_test_case(test_case, warmup_runs, benchmark_runs); - result.set_operator_name(test_case.operator_name()); - } catch (const vkcompute::vkapi::ShaderNotSupportedError&) { - result = BenchmarkResult( - test_case.name().empty() ? "unnamed_test_case" : test_case.name(), - test_case.operator_name()); - shader_not_supported = true; + // Copy data and reference to other test cases in group + for (size_t i = 1; i < indices.size(); ++i) { + size_t tc_idx = indices[i]; + TestCase& tc = test_cases[tc_idx]; + + // Copy input data from prototype + for (size_t j = 0; + j < tc.inputs().size() && j < prototype.inputs().size(); + ++j) { + auto& dest = tc.inputs()[j]; + const auto& src = prototype.inputs()[j]; + if (dest.is_tensor() && src.is_tensor() && dest.sizes == src.sizes && + dest.dtype == src.dtype) { + dest.copy_data_from(src); + } + } + + // Copy reference output data if available + if (ref_computed) { + for (size_t j = 0; j < tc.outputs().size() && j < ref_data.size(); + ++j) { + tc.outputs()[j].get_ref_float_data() = ref_data[j]; + } + } } - // Determine if this test case passed (has valid timing data) - bool vulkan_execute_succeeded = - result.get_num_iterations() > 0 && result.get_avg_time_us() > 0.0f; + // Execute and print results for all test cases in this group + for (size_t tc_idx : indices) { + TestCase& test_case = test_cases[tc_idx]; + ++test_case_counter; - if (shader_not_supported) { - result.set_correctness_status(CorrectnessStatus::SKIPPED); - } else if (!vulkan_execute_succeeded) { - result.set_correctness_status(CorrectnessStatus::FAILED); - } else if (skipped_reference_fn) { - result.set_correctness_status(CorrectnessStatus::SKIPPED); - } else { - // Reference function provided and succeeded - validate outputs - bool correctness_passed = true; - - for (size_t output_idx = 0; output_idx < test_case.num_outputs(); - ++output_idx) { - const ValueSpec& output_spec = test_case.outputs()[output_idx]; - - if (!output_spec.validate_against_reference( - test_case.get_abs_tolerance(), test_case.get_rel_tolerance())) { - correctness_passed = false; - std::cout << " Correctness validation FAILED for test " - << result.get_kernel_name() << std::endl; - print_valuespec_data(output_spec, "vulkan output"); - print_valuespec_data(output_spec, "ref output", true); - - throw std::runtime_error("Correctness validation failed"); - } + // Execute single test case + BenchmarkResult result; + bool shader_not_supported = false; + try { + result = execute_test_case(test_case, warmup_runs, benchmark_runs); + result.set_operator_name(test_case.operator_name()); + } catch (const vkcompute::vkapi::ShaderNotSupportedError&) { + result = BenchmarkResult( + test_case.name().empty() ? "unnamed_test_case" : test_case.name(), + test_case.operator_name()); + shader_not_supported = true; } - if (correctness_passed) { - result.set_correctness_status(CorrectnessStatus::PASSED); - } else { - any_correctness_failed = true; + // Determine if this test case passed (has valid timing data) + bool vulkan_execute_succeeded = + result.get_num_iterations() > 0 && result.get_avg_time_us() > 0.0f; + + if (shader_not_supported) { + result.set_correctness_status(CorrectnessStatus::SKIPPED); + } else if (!vulkan_execute_succeeded) { result.set_correctness_status(CorrectnessStatus::FAILED); - } - } + } else if (!ref_computed) { + result.set_correctness_status(CorrectnessStatus::SKIPPED); + } else { + // Reference function provided and succeeded - validate outputs + bool correctness_passed = true; + + for (size_t output_idx = 0; output_idx < test_case.num_outputs(); + ++output_idx) { + const ValueSpec& output_spec = test_case.outputs()[output_idx]; + + if (!output_spec.validate_against_reference( + test_case.get_abs_tolerance(), + test_case.get_rel_tolerance())) { + correctness_passed = false; + std::cout << " Correctness validation FAILED for test " + << result.get_kernel_name() << std::endl; + print_valuespec_data(output_spec, "vulkan output"); + print_valuespec_data(output_spec, "ref output", true); + + throw std::runtime_error("Correctness validation failed"); + } + } - // Calculate GFLOPS for this test case using the provided FLOP calculator - float case_gflops = 0.0f; - if (vulkan_execute_succeeded) { - // Use the provided FLOP calculator to get total FLOPs for this test case - int64_t total_flops = flop_calculator(test_case); - float flops = static_cast(total_flops); - float avg_time_us = result.get_avg_time_us(); - if (avg_time_us > 0.0f && total_flops > 0) { - case_gflops = (flops / 1e9f) / (avg_time_us / 1e6f); + if (correctness_passed) { + result.set_correctness_status(CorrectnessStatus::PASSED); + } else { + any_correctness_failed = true; + result.set_correctness_status(CorrectnessStatus::FAILED); + } } - total_gflops += case_gflops; - } else { - case_gflops = -1.0f; // Indicate failure - } + // Calculate GFLOPS for this test case using the provided FLOP calculator + float case_gflops = 0.0f; + if (vulkan_execute_succeeded) { + // Use the provided FLOP calculator to get total FLOPs for this test + // case + int64_t total_flops = flop_calculator(test_case); + float flops = static_cast(total_flops); + float avg_time_us = result.get_avg_time_us(); + if (avg_time_us > 0.0f && total_flops > 0) { + case_gflops = (flops / 1e9f) / (avg_time_us / 1e6f); + } - // Calculate tensor info for display - std::string size_info = "["; - if (!test_case.empty() && test_case.num_inputs() > 0 && - test_case.inputs()[0].is_tensor()) { - const auto& sizes = test_case.inputs()[0].get_tensor_sizes(); - for (size_t j = 0; j < sizes.size(); ++j) { - size_info += std::to_string(sizes[j]); - if (j < sizes.size() - 1) - size_info += "x"; + total_gflops += case_gflops; + } else { + case_gflops = -1.0f; // Indicate failure } - } - size_info += "]"; - // Print progress using the BenchmarkResult member function - result.print_summary(i + 1, size_info, case_gflops); + // Calculate tensor info for display + std::string size_info = "["; + if (!test_case.empty() && test_case.num_inputs() > 0 && + test_case.inputs()[0].is_tensor()) { + const auto& sizes = test_case.inputs()[0].get_tensor_sizes(); + for (size_t j = 0; j < sizes.size(); ++j) { + size_info += std::to_string(sizes[j]); + if (j < sizes.size() - 1) + size_info += "x"; + } + } + size_info += "]"; + + // Print progress using the BenchmarkResult member function + result.print_summary(test_case_counter, size_info, case_gflops); - // Add result to collection - results.add_result(std::move(result)); + // Add result to collection + results.add_result(std::move(result)); + } } // Set the overall results on the TestResult diff --git a/backends/vulkan/test/custom_ops/utils.h b/backends/vulkan/test/custom_ops/utils.h index 666b2d2e409..9b5b6a46782 100644 --- a/backends/vulkan/test/custom_ops/utils.h +++ b/backends/vulkan/test/custom_ops/utils.h @@ -14,6 +14,7 @@ #include #include #include +#include #include #include @@ -23,6 +24,29 @@ namespace prototyping { using namespace vkcompute; +// +// ReferenceKey for caching reference computations +// + +// Captures the identity of input conditions for test case grouping. +// Test cases with the same ReferenceKey should produce identical reference +// outputs, so reference computation can be cached and reused. +struct ReferenceKey { + std::string key_string; + + static ReferenceKey from_test_case(const class TestCase& tc); + + bool operator==(const ReferenceKey& other) const { + return key_string == other.key_string; + } +}; + +struct ReferenceKeyHash { + size_t operator()(const ReferenceKey& k) const { + return std::hash{}(k.key_string); + } +}; + // // Global configuration options // @@ -57,11 +81,70 @@ inline const std::vector kLayoutOnlyShaderFilter = { "nchw_to", "to_nchw"}; +// +// String utilities +// + +// Helper function to get abbreviated layout names for test case naming +inline std::string layout_abbrev(utils::GPUMemoryLayout layout) { + switch (layout) { + case utils::kWidthPacked: + return "WP"; + case utils::kChannelsPacked: + return "CP"; + case utils::kPackedInt8_4W: + return "4W"; + case utils::kPackedInt8_4C: + return "4C"; + case utils::kPackedInt8_4W4C: + return "4W4C"; + case utils::kPackedInt8_4H4W: + return "4H4W"; + case utils::kPackedInt8_4C1W: + return "4C1W"; + default: + return "UNK"; + } +} + +// Helper function to get abbreviated storage type names for test case naming +inline std::string storage_type_abbrev(utils::StorageType storage_type) { + switch (storage_type) { + case utils::kTexture3D: + return "Tex"; + case utils::kBuffer: + return "Buf"; + default: + return "UNK"; + } +} + +// Helper function to get combined storage type and layout representation +// Example: (kBuffer, kPackedInt8_4W4C) -> "Buf_4W4C" +inline std::string repr_str( + utils::StorageType storage_type, + utils::GPUMemoryLayout layout) { + return storage_type_abbrev(storage_type) + "(" + layout_abbrev(layout) + ")"; +} + +// Helper function to generate comma-separated shape string for test case naming +// Example: {1, 128, 56, 56} -> "1,128,56,56" +inline std::string shape_string(const std::vector& shape) { + std::string result; + for (size_t i = 0; i < shape.size(); ++i) { + if (i > 0) { + result += ","; + } + result += std::to_string(shape[i]); + } + return result; +} + // // ValueSpec class // -enum class SpecType { Tensor, IntList, Int, Float, Bool }; +enum class SpecType { Tensor, IntList, Int, Float, Bool, String }; // Data generation types enum class DataGenType { @@ -87,12 +170,14 @@ struct ValueSpec { bool is_constant_tensor; bool is_none_flag; bool is_int4_tensor; + bool data_generated_ = false; std::vector float_data; std::vector int32_data; std::vector half_data; // Using uint16_t as substitute for half std::vector int8_data; // For kChar (signed 8-bit) std::vector uint8_data; // For kByte (unsigned 8-bit) + std::string string_data; std::vector ref_float_data; std::vector ref_int32_data; @@ -113,8 +198,9 @@ struct ValueSpec { data_gen_type(DataGenType::ZEROS), is_constant_tensor(false), is_none_flag(false), - is_int4_tensor(false) { - generate_tensor_data(); + is_int4_tensor(false), + data_generated_(false) { + // Data generation is deferred until ensure_data_generated() is called } // Constructor for tensor with custom data generation type @@ -132,8 +218,9 @@ struct ValueSpec { data_gen_type(data_gen_type), is_constant_tensor(false), is_none_flag(false), - is_int4_tensor(false) { - generate_tensor_data(); + is_int4_tensor(false), + data_generated_(false) { + // Data generation is deferred until ensure_data_generated() is called } // Constructor for single int @@ -146,7 +233,8 @@ struct ValueSpec { data_gen_type(DataGenType::FIXED), is_constant_tensor(false), is_none_flag(false), - is_int4_tensor(false) { + is_int4_tensor(false), + data_generated_(true) { int32_data.push_back(value); } @@ -160,7 +248,8 @@ struct ValueSpec { data_gen_type(DataGenType::FIXED), is_constant_tensor(false), is_none_flag(false), - is_int4_tensor(false) { + is_int4_tensor(false), + data_generated_(true) { float_data.push_back(value); } @@ -174,7 +263,8 @@ struct ValueSpec { data_gen_type(DataGenType::FIXED), is_constant_tensor(false), is_none_flag(false), - is_int4_tensor(false) { + is_int4_tensor(false), + data_generated_(true) { int32_data.push_back(value ? 1 : 0); } @@ -189,8 +279,26 @@ struct ValueSpec { is_constant_tensor(false), is_none_flag(false), is_int4_tensor(false), + data_generated_(true), int32_data(values) {} + // Factory method for string (avoids ambiguity with vector constructor) + static ValueSpec make_string(const std::string& value) { + ValueSpec spec; + spec.sizes = {1}; + spec.dtype = vkapi::kInt; + spec.memory_layout = utils::kWidthPacked; + spec.storage_type = utils::kTexture3D; + spec.spec_type = SpecType::String; + spec.data_gen_type = DataGenType::FIXED; + spec.is_constant_tensor = false; + spec.is_none_flag = false; + spec.is_int4_tensor = false; + spec.data_generated_ = true; + spec.string_data = value; + return spec; + } + // Default constructor ValueSpec() : dtype(vkapi::kFloat), @@ -200,7 +308,8 @@ struct ValueSpec { data_gen_type(DataGenType::ZEROS), is_constant_tensor(false), is_none_flag(false), - is_int4_tensor(false) {} + is_int4_tensor(false), + data_generated_(false) {} int64_t numel() const; size_t nbytes() const; @@ -221,6 +330,9 @@ struct ValueSpec { bool is_bool() const { return spec_type == SpecType::Bool; } + bool is_string() const { + return spec_type == SpecType::String; + } int32_t get_int_value() const { return int32_data.empty() ? 0 : int32_data[0]; @@ -231,6 +343,9 @@ struct ValueSpec { bool get_bool_value() const { return int32_data.empty() ? false : (int32_data[0] != 0); } + const std::string& get_string_value() const { + return string_data; + } const std::vector& get_int_list() const { return int32_data; } @@ -306,12 +421,23 @@ struct ValueSpec { void* get_mutable_data_ptr(); float get_element(size_t index) const; + // Data generation methods for deferred generation and caching + bool is_data_generated() const { + return data_generated_; + } + void ensure_data_generated(int seed = -1); + void copy_data_from(const ValueSpec& other); + // Set/get constant flag bool is_constant() const { return is_constant_tensor; } void set_constant(bool is_constant) { is_constant_tensor = is_constant; + // Constant tensors need data immediately for test case setup + if (is_constant && is_tensor()) { + ensure_data_generated(); + } } // Set/get none flag @@ -341,7 +467,7 @@ struct ValueSpec { float rel_tolerance = 1e-3f) const; private: - void generate_tensor_data(); + void generate_tensor_data(int seed = -1); }; // @@ -463,6 +589,25 @@ enum class CorrectnessStatus { FAILED // Reference function provided but validation failed }; +// Per-shader timing data for detailed reporting +struct ShaderTiming { + std::string shader_name; + std::vector iter_timings_us; // Individual iteration timings + uint32_t global_wg_size[3] = {0, 0, 0}; + uint32_t local_wg_size[3] = {0, 0, 0}; + + float get_avg_time_us() const { + if (iter_timings_us.empty()) { + return 0.0f; + } + float sum = 0.0f; + for (float t : iter_timings_us) { + sum += t; + } + return sum / iter_timings_us.size(); + } +}; + class BenchmarkResult { public: BenchmarkResult() : correctness_status_(CorrectnessStatus::SKIPPED) {} @@ -480,6 +625,18 @@ class BenchmarkResult { // Add timing for a single iteration void add_iter_timing(float time_us); + // Add per-shader timing for a single iteration + void add_shader_timing( + const std::string& shader_name, + float time_us, + const uint32_t global_wg[3], + const uint32_t local_wg[3]); + + // Get per-shader timing data + const std::vector& get_shader_timings() const { + return shader_timings_; + } + // Getters const std::string& get_kernel_name() const { return kernel_name; @@ -530,6 +687,7 @@ class BenchmarkResult { std::string operator_name; std::vector iter_timings; // Individual iteration timings in microseconds + std::vector shader_timings_; // Per-shader timing data CorrectnessStatus correctness_status_; }; From e02f43dab0248de982a5e127497137f67ed5f158 Mon Sep 17 00:00:00 2001 From: ssjia Date: Thu, 5 Feb 2026 10:21:15 -0800 Subject: [PATCH 2/3] [ET-VK] Add alignment fields to PackedDimInfo for padded size calculation Pull Request resolved: https://github.com/pytorch/executorch/pull/17170 This change introduces separate alignment fields to PackedDimInfo, decoupling the alignment used for padding tensor dimensions from the block size used for packing. Previously, `calculate_padded_sizes` used `packed_dim_block_size` and `outer_packed_dim_block_size` directly to determine how much to pad tensor dimensions. This works but limits flexibility - there are scenarios where we want to pad dimensions to a larger alignment than the block size for performance reasons, such as ensuring loads are aligned to cache lines or removing the need for bounds checking in shaders. The new fields `packed_dim_align` and `outer_packed_dim_align` allow specifying the alignment independently. For now, these are initialized to match the corresponding block sizes, preserving existing behavior. Future changes can set larger alignment values when beneficial for performance. Authored with Claude. ghstack-source-id: 338638551 @exported-using-ghexport Differential Revision: [D92196649](https://our.internmc.facebook.com/intern/diff/D92196649/) --- .../vulkan/runtime/api/containers/Tensor.cpp | 134 ++++++++++++++---- .../vulkan/runtime/api/containers/Tensor.h | 10 ++ backends/vulkan/runtime/utils/StorageUtils.h | 98 ------------- 3 files changed, 116 insertions(+), 126 deletions(-) diff --git a/backends/vulkan/runtime/api/containers/Tensor.cpp b/backends/vulkan/runtime/api/containers/Tensor.cpp index 4cf949ba5ab..351b920e805 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.cpp +++ b/backends/vulkan/runtime/api/containers/Tensor.cpp @@ -17,13 +17,17 @@ namespace api { PackedDimInfo::PackedDimInfo( const int32_t dim, const int32_t dim_block_size, + const int32_t dim_align, const int32_t outer_dim, const int32_t outer_dim_block_size, + const int32_t outer_dim_align, const bool is_block_transposed) : packed_dim(dim), packed_dim_block_size(dim_block_size), + packed_dim_align(dim_align), outer_packed_dim(outer_dim), outer_packed_dim_block_size(outer_dim_block_size), + outer_packed_dim_align(outer_dim_align), block_transposed(is_block_transposed), block_numel(packed_dim_block_size * outer_packed_dim_block_size) { // Packed dims must be different @@ -33,19 +37,97 @@ PackedDimInfo::PackedDimInfo( PackedDimInfo calculate_packed_dim_info( const utils::GPUMemoryLayout memory_layout, const utils::StorageType storage_type) { - const int32_t packed_dim = utils::to_packed_dim(memory_layout); - const int32_t outer_packed_dim = - utils::to_outer_packed_dim(memory_layout); - const int32_t packed_dim_block_size = - utils::to_packed_dim_block_size(memory_layout, storage_type); - const int32_t outer_packed_dim_block_size = - utils::to_outer_packed_dim_block_size(memory_layout); - const bool is_block_transposed = - utils::is_block_transposed_layout(memory_layout); - - const int32_t block_numel = - packed_dim_block_size * outer_packed_dim_block_size; - if (storage_type != utils::kBuffer) { + const bool is_buffer = storage_type == utils::kBuffer; + + PackedDimInfo packed_dim_info(0, 1, 1, 1, 1, 1, false); + switch (memory_layout) { + case utils::kWidthPacked: + packed_dim_info = PackedDimInfo( + /*dim=*/0, + /*dim_block_size=*/is_buffer ? 1 : 4, + /*dim_align=*/is_buffer ? 1 : 4, + /*outer_dim=*/1, + /*outer_dim_block_size=*/1, + /*outer_dim_align=*/1, + /*is_block_transposed=*/false); + break; + case utils::kHeightPacked: + packed_dim_info = PackedDimInfo( + /*dim=*/1, + /*dim_block_size=*/is_buffer ? 1 : 4, + /*dim_align=*/is_buffer ? 1 : 4, + /*outer_dim=*/0, + /*outer_dim_block_size=*/1, + /*outer_dim_align=*/1, + /*is_block_transposed=*/false); + break; + case utils::kChannelsPacked: + packed_dim_info = PackedDimInfo( + /*dim=*/2, + /*dim_block_size=*/is_buffer ? 1 : 4, + /*dim_align=*/is_buffer ? 1 : 4, + /*outer_dim=*/0, + /*outer_dim_block_size=*/1, + /*outer_dim_align=*/1, + /*is_block_transposed=*/false); + break; + case utils::kPackedInt8_4W: + packed_dim_info = PackedDimInfo( + /*dim=*/0, + /*dim_block_size=*/is_buffer ? 4 : 16, + /*dim_align=*/is_buffer ? 4 : 16, + /*outer_dim=*/1, + /*outer_dim_block_size=*/1, + /*outer_dim_align=*/1, + /*is_block_transposed=*/false); + break; + case utils::kPackedInt8_4C: + packed_dim_info = PackedDimInfo( + /*dim=*/2, + /*dim_block_size=*/is_buffer ? 4 : 16, + /*dim_align=*/is_buffer ? 4 : 16, + /*outer_dim=*/0, + /*outer_dim_block_size=*/1, + /*outer_dim_align=*/1, + /*is_block_transposed=*/false); + break; + case utils::kPackedInt8_4W4C: + packed_dim_info = PackedDimInfo( + /*dim=*/2, + /*dim_block_size=*/4, + /*dim_align=*/4, + /*outer_dim=*/0, + /*outer_dim_block_size=*/4, + /*outer_dim_align=*/4, + /*is_block_transposed=*/false); + break; + case utils::kPackedInt8_4H4W: + packed_dim_info = PackedDimInfo( + /*dim=*/0, + /*dim_block_size=*/4, + /*dim_align=*/4, + /*outer_dim=*/1, + /*outer_dim_block_size=*/4, + /*outer_dim_align=*/4, + /*is_block_transposed=*/false); + break; + case utils::kPackedInt8_4C1W: + packed_dim_info = PackedDimInfo( + /*dim=*/2, + /*dim_block_size=*/is_buffer ? 4 : 16, + /*dim_align=*/is_buffer ? 4 : 16, + /*outer_dim=*/0, + /*outer_dim_block_size=*/1, + /*outer_dim_align=*/1, + /*is_block_transposed=*/true); + break; + default: + VK_THROW("Unknown GPUMemoryLayout"); + } + + if (!is_buffer) { + const int32_t block_numel = packed_dim_info.packed_dim_block_size * + packed_dim_info.outer_packed_dim_block_size; if (is_packed_int8_layout(memory_layout)) { VK_CHECK_COND(block_numel == 16); } else { @@ -53,12 +135,7 @@ PackedDimInfo calculate_packed_dim_info( } } - return PackedDimInfo( - packed_dim, - packed_dim_block_size, - outer_packed_dim, - outer_packed_dim_block_size, - is_block_transposed); + return packed_dim_info; } /* @@ -297,7 +374,8 @@ utils::ivec4 flip_and_unsqueeze_ivec4( * for GPU storage in the following ways: * * 1. The dimensionality of the tensor will be padded to a multiple of 4. - * 2. The size of the packed dimension will be padded to a multiple of 4. + * 2. The size of the packed dimension will be padded to a multiple of the + * packed dimension's alignment value. * * The "packed dimension" is determined based on the utils::GPUMemoryLayout * argument. @@ -317,23 +395,23 @@ std::vector calculate_padded_sizes( padded_sizes.at(i) = utils::val_at(i - ndim_up4, sizes); } - // Pad the packed dim to the block size - if (packed_dim_info.packed_dim_block_size > 1) { + // Pad the packed dim to the alignment + if (packed_dim_info.packed_dim_align > 1) { const int64_t dim_offset = packed_dim_info.packed_dim + 1; const int64_t padded_dim_size = utils::val_at(-dim_offset, sizes); padded_sizes.at(ndim_up4 - dim_offset) = utils::align_up( padded_dim_size, - static_cast(packed_dim_info.packed_dim_block_size)); + static_cast(packed_dim_info.packed_dim_align)); } - // Also pad the outer packed dimension if it's different from the inner packed - // dimension and is marked as padded. - if (packed_dim_info.outer_packed_dim_block_size > 1) { + // Also pad the outer packed dimension if it has alignment > 1. + if (packed_dim_info.outer_packed_dim_align > 1) { const int64_t outer_dim_offset = packed_dim_info.outer_packed_dim + 1; const int64_t outer_padded_dim_size = utils::val_at(-outer_dim_offset, sizes); - padded_sizes.at(ndim_up4 - outer_dim_offset) = - utils::align_up_4(outer_padded_dim_size); + padded_sizes.at(ndim_up4 - outer_dim_offset) = utils::align_up( + outer_padded_dim_size, + static_cast(packed_dim_info.outer_packed_dim_align)); } return padded_sizes; diff --git a/backends/vulkan/runtime/api/containers/Tensor.h b/backends/vulkan/runtime/api/containers/Tensor.h index 8341097b490..301666f45c6 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.h +++ b/backends/vulkan/runtime/api/containers/Tensor.h @@ -67,6 +67,12 @@ struct PackedDimInfo { // In physical memory, the size of the packed dim is aligned to this size to // ensure that data for the packed dim aligns with texel/block boundaries. int32_t packed_dim_block_size; + // In physical memory, the size of the packed dimension will be aligned to be + // a multiple of this value. This value must be a multiple of the packed_dim's + // block size, and is selected for performance reasons i.e. to ensure loads + // along the packed dim are aligned to cache lines, or to enable performance + // optimizations in shaders, i.e. remove the need for bounds checking. + int32_t packed_dim_align; // For block-packed layouts, represents the second tensor dimension that forms // the "width" dimension of the MxN square that is kept contiguous in memory. // For non block-packed layouts, represent the dimension with the next lowest @@ -77,6 +83,8 @@ struct PackedDimInfo { // 4H4W, represents the "height" of the square block that is kept contiguous // in memory. int32_t outer_packed_dim_block_size; + // See packed_dim_align + int32_t outer_packed_dim_align; // Typically the blocks of the tensor will be arranged such that the inner // dim of the block (i.e. the packed dim) has the lowest stride, and the // outer dim of the block (i.e. the outer packed dim) has the next lowest @@ -94,8 +102,10 @@ struct PackedDimInfo { PackedDimInfo( const int32_t dim, const int32_t dim_block_size, + const int32_t dim_align, const int32_t outer_dim, const int32_t outer_dim_block_size, + const int32_t outer_dim_align, const bool is_block_transposed); }; diff --git a/backends/vulkan/runtime/utils/StorageUtils.h b/backends/vulkan/runtime/utils/StorageUtils.h index d2997019a8b..d2978f1d662 100644 --- a/backends/vulkan/runtime/utils/StorageUtils.h +++ b/backends/vulkan/runtime/utils/StorageUtils.h @@ -139,104 +139,6 @@ static constexpr GPUMemoryLayout kPackedInt8_4H4W = static constexpr GPUMemoryLayout kPackedInt8_4C1W = GPUMemoryLayout::TENSOR_PACKED_INT8_4C1W; -template -T to_packed_dim(const GPUMemoryLayout layout) { - switch (layout) { - case kWidthPacked: - return 0; - case kHeightPacked: - return 1; - case kChannelsPacked: - return 2; - case kPackedInt8_4W: - return 0; - case kPackedInt8_4C: - return 2; - case kPackedInt8_4W4C: - return 2; - case kPackedInt8_4H4W: - return 0; - case kPackedInt8_4C1W: - return 2; - }; - // Should be unreachable - return 0; -} - -template -T to_outer_packed_dim(const GPUMemoryLayout layout) { - switch (layout) { - case kWidthPacked: - return 1; - case kHeightPacked: - return 0; - case kChannelsPacked: - return 0; - case kPackedInt8_4W: - return 1; - case kPackedInt8_4C: - return 0; - case kPackedInt8_4W4C: - return 0; - case kPackedInt8_4H4W: - return 1; - case kPackedInt8_4C1W: - return 0; - }; - // Should be unreachable - return 1; -} - -template -T to_packed_dim_block_size( - const GPUMemoryLayout layout, - const StorageType storage) { - switch (layout) { - case kWidthPacked: - return storage == kBuffer ? 1 : 4; - case kHeightPacked: - return storage == kBuffer ? 1 : 4; - case kChannelsPacked: - return storage == kBuffer ? 1 : 4; - case kPackedInt8_4W: - return storage == kBuffer ? 4 : 16; - case kPackedInt8_4C: - return storage == kBuffer ? 4 : 16; - case kPackedInt8_4W4C: - return 4; - case kPackedInt8_4H4W: - return 4; - case kPackedInt8_4C1W: - return storage == kBuffer ? 4 : 16; - }; - // Should be unreachable - return 1; -} - -template -T to_outer_packed_dim_block_size(const GPUMemoryLayout layout) { - switch (layout) { - case kWidthPacked: - return 1; - case kHeightPacked: - return 1; - case kChannelsPacked: - return 1; - case kPackedInt8_4W: - return 1; - case kPackedInt8_4C: - return 1; - case kPackedInt8_4W4C: - return 4; - case kPackedInt8_4H4W: - return 4; - case kPackedInt8_4C1W: - return 1; - }; - // Should be unreachable - return 1; -} - bool is_block_transposed_layout(const GPUMemoryLayout layout); bool is_packed_int8_layout(const GPUMemoryLayout layout); From 694f9b8cc35e0880711d40dbf0bb4744eb2d0d00 Mon Sep 17 00:00:00 2001 From: pytorchbot Date: Thu, 5 Feb 2026 15:57:03 -0800 Subject: [PATCH 3/3] [ET-VK][quantization] Implement layout-flexible quantize/dequantize operators (#17261) Implemented quantize_per_tensor and dequantize_per_tensor GLSL shaders and C++ dispatch logic to support the new single-dimension packed INT8 layouts (kPackedInt8_4W, kPackedInt8_4C, kPackedInt8_4H). These operators enable conversion between floating-point tensors and packed int8 representations with per-tensor scale and zero-point parameters. The implementation includes: - GLSL shaders: quantize_per_tensor and dequantize_per_tensor with support for both texture->buffer and buffer->buffer data flows, including GL_EXT_debug_printf statements for debugging - QuantizeDequantize.cpp: Added dispatch functions for the new layouts and registered etvk.q_dq_8bit_per_tensor.default operator - Test infrastructure: Created q_dq_8bit_per_tensor test binary with DEBUG_MODE support and reference CPU implementation for validation The shaders implement the quantization formula Q = clamp(round(x/scale) + zp, -128, 127) and dequantization formula x' = (Q - zp) * scale, with proper int8 packing/unpacking using little-endian byte ordering and sign extension. Differential Revision: [D92061370](https://our.internmc.facebook.com/intern/diff/D92061370/) [ghstack-poisoned] --- .github/workflows/pull.yml | 2 +- .../graph/ops/glsl/block_indexing.glslh | 280 ++++++++++++++ .../graph/ops/glsl/block_int8x4_load.glslh | 74 ++++ .../graph/ops/glsl/block_int8x4_store.glslh | 74 ++++ .../graph/ops/glsl/block_int8x4_utils.glslh | 109 ++++++ .../runtime/graph/ops/glsl/block_load.glslh | 105 ++++++ .../runtime/graph/ops/glsl/block_store.glslh | 97 +++++ .../runtime/graph/ops/glsl/common.glslh | 3 + .../runtime/graph/ops/glsl/indexing.glslh | 116 ++++++ .../graph/ops/glsl/q8ta_dequantize.glsl | 120 ++++++ .../graph/ops/glsl/q8ta_dequantize.yaml | 18 + .../runtime/graph/ops/glsl/q8ta_quantize.glsl | 118 ++++++ .../runtime/graph/ops/glsl/q8ta_quantize.yaml | 18 + .../vulkan/runtime/graph/ops/impl/Common.cpp | 281 ++++++++++++++ .../vulkan/runtime/graph/ops/impl/Common.h | 196 ++++++++++ .../graph/ops/impl/Q8taQuantizeDequantize.cpp | 156 ++++++++ .../graph/ops/impl/Q8taQuantizeDequantize.h | 33 ++ .../graph/ops/impl/QuantizeDequantize.cpp | 41 +-- .../graph/ops/impl/QuantizedConvolution.cpp | 6 +- .../vulkan/test/custom_ops/CMakeLists.txt | 2 +- .../impl/TestQ8taQuantizeDequantize.cpp | 71 ++++ .../custom_ops/qdq8ta_conv2d_activations.cpp | 255 ------------- backends/vulkan/test/custom_ops/targets.bzl | 2 +- .../vulkan/test/custom_ops/test_q8ta_qdq.cpp | 348 ++++++++++++++++++ 24 files changed, 2232 insertions(+), 293 deletions(-) create mode 100644 backends/vulkan/runtime/graph/ops/glsl/block_indexing.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/block_int8x4_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/block_int8x4_store.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/block_int8x4_utils.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/block_load.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/block_store.glslh create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q8ta_dequantize.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q8ta_dequantize.yaml create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q8ta_quantize.glsl create mode 100644 backends/vulkan/runtime/graph/ops/glsl/q8ta_quantize.yaml create mode 100644 backends/vulkan/runtime/graph/ops/impl/Q8taQuantizeDequantize.cpp create mode 100644 backends/vulkan/runtime/graph/ops/impl/Q8taQuantizeDequantize.h create mode 100644 backends/vulkan/test/custom_ops/impl/TestQ8taQuantizeDequantize.cpp delete mode 100644 backends/vulkan/test/custom_ops/qdq8ta_conv2d_activations.cpp create mode 100644 backends/vulkan/test/custom_ops/test_q8ta_qdq.cpp diff --git a/.github/workflows/pull.yml b/.github/workflows/pull.yml index 2645be6478e..874b2a65168 100644 --- a/.github/workflows/pull.yml +++ b/.github/workflows/pull.yml @@ -1137,7 +1137,7 @@ jobs: ./cmake-out/backends/vulkan/test/custom_ops/q8csw_conv2d ./cmake-out/backends/vulkan/test/custom_ops/q4gsw_linear ./cmake-out/backends/vulkan/test/custom_ops/choose_qparams_per_row - ./cmake-out/backends/vulkan/test/custom_ops/qdq8ta_conv2d_activations + ./cmake-out/backends/vulkan/test/custom_ops/test_q8ta_qdq ./cmake-out/backends/vulkan/test/custom_ops/q8ta_q8ta_q8to_add # "Classic" Operator tests diff --git a/backends/vulkan/runtime/graph/ops/glsl/block_indexing.glslh b/backends/vulkan/runtime/graph/ops/glsl/block_indexing.glslh new file mode 100644 index 00000000000..e7e64a601ef --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/block_indexing.glslh @@ -0,0 +1,280 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifndef BLOCK_GLSLH +#define BLOCK_GLSLH + +#include "block_int8x4_utils.glslh" +#include "common.glslh" + +// +// Block Layout Int Utils +// + +// These macros extract fields from the packed int returned by +// BlockConfig::as_packed_int(). See Common.h for the bit layout. +// +// Bit layout matches hashed layout format: +// bits 0- 3: block_dim_order[0] (inner_dim if not transposed, outer_dim if transposed) +// bits 4- 7: block_dim_order[1] (outer_dim if not transposed, inner_dim if transposed) +// bits 8-11: block_dim_order[2] (first_nonblock_dim) +// bits 12-15: block_dim_order[3] (second_nonblock_dim) +// bits 16-19: inner_dim +// bits 20-23: outer_dim +// bits 24-27: inner_dim_block_size +// bits 28-31: outer_dim_block_size + +// Extract block_dim_order elements (bits 0-15) +#define get_block_dim_order_0(x) ((x) & 0xF) +#define get_block_dim_order_1(x) (((x) >> 4) & 0xF) +#define get_block_dim_order_2(x) (((x) >> 8) & 0xF) +#define get_block_dim_order_3(x) (((x) >> 12) & 0xF) + +// Extract packed_dim_info (bits 16-31) +#define get_block_inner_dim(x) (((x) >> 16) & 0xF) +#define get_block_outer_dim(x) (((x) >> 20) & 0xF) +#define get_block_inner_dim_block_size(x) (((x) >> 24) & 0xF) +#define get_block_outer_dim_block_size(x) (((x) >> 28) & 0xF) + +/* + * Block-based programming utilities for compute shaders. + * + * A "block" is a 4x4 tile of tensor elements. The two dimensions of the block + * are called the "inner" dimension and the "outer" dimension. The inner dim is + * the one that is kept contiguous in memory (i.e. the packed dimension of the + * tensor), and the outer dim forms the other axis of the 2D block. + * + * For texel-packed tensors (single level of packing), a block is effectively a + * single texel repeated 4 times along the outer dimension. For block-packed + * tensors (two levels of packing), a block corresponds exactly to the 4x4 + * packed unit. + * + * When dispatching a block-based shader: + * - gl_GlobalInvocationID.x = block index along inner dimension + * - gl_GlobalInvocationID.y = block index along outer dimension + * - gl_GlobalInvocationID.z = plane index (remaining dimensions flattened) + */ + +// +// Index Conversion Utilities (TensorIndex4D versions) +// + +TensorIndex4D contiguous_block_idx_to_tensor4d_idx_with_block_config( + const BufferMetadata meta, + const uint block_idx, + const int block_config) { + TensorIndex4D tidx; + + uint block_strides[4]; + + uint stride = 1; + // Inner block dim + const int packed_dim_1 = get_block_inner_dim(block_config); + block_strides[packed_dim_1] = 1; + const uint block_size_1 = uint(get_block_inner_dim_block_size(block_config)); + stride = div_up(meta.sizes[0][packed_dim_1], block_size_1); + // Outer block dim + const int packed_dim_2 = get_block_outer_dim(block_config); + block_strides[packed_dim_2] = stride; + const uint block_size_2 = + uint(get_block_outer_dim_block_size(block_config)); + stride *= div_up(meta.sizes[0][packed_dim_2], block_size_2); + // First non-block dim + const int outer_dim_1 = get_block_dim_order_2(block_config); + block_strides[outer_dim_1] = stride; + stride *= meta.sizes[0][outer_dim_1]; + // Second non-block dim + const int outer_dim_2 = get_block_dim_order_3(block_config); + block_strides[outer_dim_2] = stride; + + uint contig_idx = block_idx; + // Second non-block dim + tidx.data[outer_dim_2] = int(contig_idx / block_strides[outer_dim_2]); + contig_idx %= block_strides[outer_dim_2]; + // First non-block dim (1; height) + tidx.data[outer_dim_1] = int(contig_idx / block_strides[outer_dim_1]); + contig_idx %= block_strides[outer_dim_1]; + // Outer block dim (0; width) + tidx.data[packed_dim_2] = + int(mul_4(contig_idx / block_strides[packed_dim_2])); + contig_idx %= block_strides[packed_dim_2]; + // Inner block dim (2; channels) + tidx.data[packed_dim_1] = int(mul_4(contig_idx)); + + return tidx; +} + +// +// TextureMetadata variants of block indexing +// + +TensorIndex4D contiguous_block_idx_to_tensor4d_idx_with_block_config( + const TextureMetadata meta, + const uint block_idx, + const int block_config) { + TensorIndex4D tidx; + + uint block_strides[4]; + + uint stride = 1; + // Inner block dim + const int packed_dim_1 = get_block_inner_dim(block_config); + block_strides[packed_dim_1] = 1; + const uint block_size_1 = uint(get_block_inner_dim_block_size(block_config)); + stride = div_up(meta.sizes[packed_dim_1], block_size_1); + // Outer block dim + const int packed_dim_2 = get_block_outer_dim(block_config); + block_strides[packed_dim_2] = stride; + const uint block_size_2 = + uint(get_block_outer_dim_block_size(block_config)); + stride *= div_up(meta.sizes[packed_dim_2], block_size_2); + // First non-block dim + const int outer_dim_1 = get_block_dim_order_2(block_config); + block_strides[outer_dim_1] = stride; + stride *= meta.sizes[outer_dim_1]; + // Second non-block dim + const int outer_dim_2 = get_block_dim_order_3(block_config); + block_strides[outer_dim_2] = stride; + + uint contig_idx = block_idx; + // Second non-block dim + tidx.data[outer_dim_2] = int(contig_idx / block_strides[outer_dim_2]); + contig_idx %= block_strides[outer_dim_2]; + // First non-block dim + tidx.data[outer_dim_1] = int(contig_idx / block_strides[outer_dim_1]); + contig_idx %= block_strides[outer_dim_1]; + // Outer block dim + tidx.data[packed_dim_2] = + int(mul_4(contig_idx / block_strides[packed_dim_2])); + contig_idx %= block_strides[packed_dim_2]; + // Inner block dim + tidx.data[packed_dim_1] = int(mul_4(contig_idx)); + + return tidx; +} + +// +// 3D Block Index Conversion Utilities (WHCN Dispatch) +// +// These functions convert a 3D thread index (gl_GlobalInvocationID) to a +// TensorIndex4D using a dispatch pattern: +// - thread_idx.x = W threads (divided by 4 if W is part of block) +// - thread_idx.y = H threads (divided by 4 if H is part of block) +// - thread_idx.z = C * N threads (C divided by 4 if C is part of block) +// +// Note: GLSL tensor metadata is in WHCN order (sizes[0]=W, sizes[1]=H, +// sizes[2]=C, sizes[3]=N), while C++ uses NCHW order. +// + +/* + * Convert a 3D block index to a TensorIndex4D using WHCN dispatch. + * + * Parameters: + * meta: BufferMetadata with tensor sizes in WHCN order + * thread_idx: 3D thread index (x=W, y=H, z=C*N) + * block_config: Packed block configuration from BlockConfig::as_packed_int() + * + * Returns: TensorIndex4D with logical tensor coordinates + */ +TensorIndex4D block_idx_3d_to_tensor4d_idx_with_block_config( + const BufferMetadata meta, + const uvec3 thread_idx, + const int block_config) { + TensorIndex4D tidx; + const int inner_dim = get_block_inner_dim(block_config); + const int outer_dim = get_block_outer_dim(block_config); + + // GLSL metadata is in WHCN order: sizes[0]=W, sizes[1]=H, sizes[2]=C, sizes[3]=N + + // Compute C threads for decomposing thread_idx.z + // C is blocked (divided by 4) only if it's part of the block + uint C_size = uint(meta.sizes[0][2]); + uint num_C; + if (inner_dim == 2 || outer_dim == 2) { + num_C = div_up_4(C_size); + } else { + num_C = C_size; + } + + // W (dim 0): blocked if inner or outer + if (inner_dim == 0 || outer_dim == 0) { + tidx.data[0] = int(thread_idx.x) * 4; // Block-aligned + } else { + tidx.data[0] = int(thread_idx.x); // Single value + } + + // H (dim 1): blocked if inner or outer + if (inner_dim == 1 || outer_dim == 1) { + tidx.data[1] = int(thread_idx.y) * 4; // Block-aligned + } else { + tidx.data[1] = int(thread_idx.y); // Single value + } + + // C (dim 2): blocked if inner or outer + if (inner_dim == 2 || outer_dim == 2) { + tidx.data[2] = int(thread_idx.z % num_C) * 4; // Block-aligned + } else { + tidx.data[2] = int(thread_idx.z % num_C); // Single value + } + + // N (dim 3): never blocked + tidx.data[3] = int(thread_idx.z / num_C); + + return tidx; +} + +/* + * Convert a 3D block index to a TensorIndex4D (TextureMetadata variant). + */ +TensorIndex4D block_idx_3d_to_tensor4d_idx_with_block_config( + const TextureMetadata meta, + const uvec3 thread_idx, + const int block_config) { + TensorIndex4D tidx; + const int inner_dim = get_block_inner_dim(block_config); + const int outer_dim = get_block_outer_dim(block_config); + + // GLSL metadata is in WHCN order: sizes[0]=W, sizes[1]=H, sizes[2]=C, sizes[3]=N + + // Compute C threads for decomposing thread_idx.z + uint C_size = uint(meta.sizes[2]); + uint num_C; + if (inner_dim == 2 || outer_dim == 2) { + num_C = div_up_4(C_size); + } else { + num_C = C_size; + } + + // W (dim 0): blocked if inner or outer + if (inner_dim == 0 || outer_dim == 0) { + tidx.data[0] = int(thread_idx.x) * 4; + } else { + tidx.data[0] = int(thread_idx.x); + } + + // H (dim 1): blocked if inner or outer + if (inner_dim == 1 || outer_dim == 1) { + tidx.data[1] = int(thread_idx.y) * 4; + } else { + tidx.data[1] = int(thread_idx.y); + } + + // C (dim 2): blocked if inner or outer + if (inner_dim == 2 || outer_dim == 2) { + tidx.data[2] = int(thread_idx.z % num_C) * 4; + } else { + tidx.data[2] = int(thread_idx.z % num_C); + } + + // N (dim 3): never blocked + tidx.data[3] = int(thread_idx.z / num_C); + + return tidx; +} + +#endif // BLOCK_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_load.glslh new file mode 100644 index 00000000000..6ea636a0a17 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_load.glslh @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Macro to generate int8x4 block loading functions for a specific buffer. + * + * Usage: + * define_load_int8x4_buffer_fns(t_inp) + * + * This generates: + * - load_int8x4_block_from_t_inp(meta, tidx_base, layout, block_outer_dim) + * + * IMPORTANT: block_outer_dim must be such that the inner dimension (packed_dim) + * contains 4 contiguous int8 elements packed into one int32. If the loaded + * block needs to be transposed to match a different output layout, that + * transposition must be done by the caller. + */ + +#ifndef BLOCK_INT8X4_LOAD_GLSLH +#define BLOCK_INT8X4_LOAD_GLSLH + +#define define_load_int8x4_buffer_fns(buffer_name) \ + \ + ivec4 load_int8x4_block_from_##buffer_name( \ + const BufferMetadata meta, \ + const TensorIndex4D tidx_base, \ + const int hashed_layout, \ + const int block_outer_dim) { \ + const int outer_packed_dim = get_outer_packed_dim(hashed_layout); \ + const int outer_block_size = \ + get_outer_packed_dim_block_size(hashed_layout); \ + \ + /* Compute base packed index using block-based indexing */ \ + const uint block_idx = \ + tensor4d_idx_to_block_idx(meta, tidx_base, hashed_layout); \ + const uint texels_per_block = div_4(get_block_numel(hashed_layout)); \ + uint buf_idx = block_idx * texels_per_block; \ + \ + /* Fast path: contiguous texels when iterating along outer_packed_dim */ \ + if (outer_block_size == 4) { \ + if (block_outer_dim == outer_packed_dim) { \ + return ivec4( \ + buffer_name[buf_idx], \ + buffer_name[buf_idx + 1], \ + buffer_name[buf_idx + 2], \ + buffer_name[buf_idx + 3]); \ + } \ + else { \ + buf_idx += mod_4(tidx_base.data[outer_packed_dim]); \ + } \ + } \ + \ + /* General path: use stride for non-contiguous access */ \ + const uint outer_stride = \ + stride_at(meta, block_outer_dim) * texels_per_block; \ + const uint outer_size = size_at(meta, block_outer_dim); \ + const int base_outer_idx = tidx_base.data[block_outer_dim]; \ + \ + ivec4 block = ivec4(0); \ + [[unroll]] for (int block_y = 0; block_y < 4; ++block_y) { \ + if (base_outer_idx + block_y < int(outer_size)) { \ + block[block_y] = buffer_name[buf_idx]; \ + } \ + buf_idx += outer_stride; \ + } \ + return block; \ + } + +#endif // BLOCK_INT8X4_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_store.glslh new file mode 100644 index 00000000000..2a0e037c291 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_store.glslh @@ -0,0 +1,74 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Macro to generate int8x4 block storing functions for a specific buffer. + * + * Usage: + * define_store_int8x4_buffer_fns(t_out) + * + * This generates: + * - store_int8x4_block_to_t_out(meta, tidx_base, layout, block_outer_dim, + * block) + * + * IMPORTANT: block_outer_dim must be such that the inner dimension (packed_dim) + * contains 4 contiguous int8 elements packed into one int32. If the block needs + * to be transposed to match the output layout, that transposition must be done + * by the caller before storing. + */ + +#ifndef BLOCK_INT8X4_STORE_GLSLH +#define BLOCK_INT8X4_STORE_GLSLH + +#define define_store_int8x4_buffer_fns(buffer_name) \ + \ + void store_int8x4_block_to_##buffer_name( \ + const BufferMetadata meta, \ + const TensorIndex4D tidx_base, \ + const int hashed_layout, \ + const int block_outer_dim, \ + const ivec4 block) { \ + const int outer_packed_dim = get_outer_packed_dim(hashed_layout); \ + const int outer_block_size = \ + get_outer_packed_dim_block_size(hashed_layout); \ + \ + /* Compute base packed index using block-based indexing */ \ + const uint block_idx = \ + tensor4d_idx_to_block_idx(meta, tidx_base, hashed_layout); \ + const uint texels_per_block = div_4(get_block_numel(hashed_layout)); \ + uint buf_idx = block_idx * texels_per_block; \ + \ + /* Fast path: contiguous texels when iterating along outer_packed_dim */ \ + if (outer_block_size == 4) { \ + if (block_outer_dim == outer_packed_dim) { \ + buffer_name[buf_idx] = block[0]; \ + buffer_name[buf_idx + 1] = block[1]; \ + buffer_name[buf_idx + 2] = block[2]; \ + buffer_name[buf_idx + 3] = block[3]; \ + return; \ + } \ + else { \ + buf_idx += mod_4(tidx_base.data[outer_packed_dim]); \ + } \ + } \ + \ + /* General path: use stride for non-contiguous access */ \ + const uint outer_stride = \ + stride_at(meta, block_outer_dim) * texels_per_block; \ + const uint outer_size = size_at(meta, block_outer_dim); \ + const int base_outer_idx = tidx_base.data[block_outer_dim]; \ + \ + [[unroll]] for (int block_y = 0; block_y < 4; ++block_y) { \ + if (base_outer_idx + block_y < int(outer_size)) { \ + buffer_name[buf_idx] = block[block_y]; \ + } \ + buf_idx += outer_stride; \ + } \ + } + +#endif // BLOCK_INT8X4_STORE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_utils.glslh b/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_utils.glslh new file mode 100644 index 00000000000..fea87428d0d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/block_int8x4_utils.glslh @@ -0,0 +1,109 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Utility functions for working with int8x4 blocks. + * + * An int8x4 block is a 4x4 tile of int8 values stored as an ivec4, + * where each ivec4 element contains 4 packed int8 values (one row of the + * block). + */ + +#ifndef BLOCK_INT8X4_UTILS_GLSLH +#define BLOCK_INT8X4_UTILS_GLSLH + +/* + * Write a single int8 value to a specific position in an int8x4 block. + * + * Parameters: + * block: The block to modify (ivec4 where each element has 4 packed int8) + * y: Row index (0-3) + * x: Column index (0-3) + * val: The int8 value to write + */ +void write_int8x4_block_element( + inout ivec4 block, + const int y, + const int x, + const int val) { + int texel = block[y]; + // Create a mask to clear the byte at position x, then insert val + int shift = x * 8; + int mask = ~(0xFF << shift); + block[y] = (texel & mask) | ((val & 0xFF) << shift); +} + +/* + * Transpose a 4x4 int8 block. + * + * Given block[y][x], produces result[x][y]. + * Each ivec4 element contains 4 packed int8 values. + */ +ivec4 transpose_int8x4_block(const ivec4 block) { + ivec4 result; + [[unroll]] for (int y = 0; y < 4; ++y) { + int packed = 0; + [[unroll]] for (int x = 0; x < 4; ++x) { + // Extract byte y from block[x] + int val = (block[x] >> (8 * y)) & 0xFF; + // Pack into position x + packed |= (val << (8 * x)); + } + result[y] = packed; + } + return result; +} + +// +// Debug print functions +// + +#ifdef DEBUG_MODE +/* + * Debug print function for Int8x4Block. + * + * Prints all 16 int8 values as a 4x4 matrix. + */ +void printInt8x4Block(const ivec4 block) { + // Unpack all 16 int8 values into an array + int v[16]; + [[unroll]] for (int i = 0; i < 4; ++i) { + int packed = block[i]; + v[i * 4 + 0] = (packed >> 0) & 0xFF; + v[i * 4 + 1] = (packed >> 8) & 0xFF; + v[i * 4 + 2] = (packed >> 16) & 0xFF; + v[i * 4 + 3] = (packed >> 24) & 0xFF; + } + // Sign extend from 8-bit to print as signed + [[unroll]] for (int i = 0; i < 16; ++i) { + if (v[i] > 127) + v[i] -= 256; + } + // Print as a 4x4 square in a single call to avoid interleaving + debugPrintfEXT( + "Int8x4Block:\\n [%d, %d, %d, %d]\\n [%d, %d, %d, %d]\\n [%d, %d, %d, %d]\\n [%d, %d, %d, %d]\\n", + v[0], + v[1], + v[2], + v[3], + v[4], + v[5], + v[6], + v[7], + v[8], + v[9], + v[10], + v[11], + v[12], + v[13], + v[14], + v[15]); +} +#endif // DEBUG_MODE + +#endif // BLOCK_INT8X4_UTILS_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/block_load.glslh b/backends/vulkan/runtime/graph/ops/glsl/block_load.glslh new file mode 100644 index 00000000000..d72a176aa0e --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/block_load.glslh @@ -0,0 +1,105 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Macros to generate block load functions for buffers and textures. + * + * Buffer usage: + * define_load_buffer_fns(t_inp) + * + * Texture usage: + * define_load_texture_fns(t_inp) + * + * Both generate functions with the same signature: + * - load_fp_block_from_t_inp(meta, tidx_base, layout, block_outer_dim) + * + * The block_inner_dim is derived from the packed_dim of the hashed_layout. + * If the loaded block needs to be transposed to match a different output + * layout, that transposition must be done by the caller. + * + * Parameters: + * buffer_name/texture_name: The name of the input buffer/texture (e.g., t_inp) + */ + +#ifndef BLOCK_LOAD_GLSLH +#define BLOCK_LOAD_GLSLH + +// +// Buffer load functions +// + +#define define_load_buffer_fns(buffer_name) \ + \ + mat4 load_fp_block_from_##buffer_name( \ + const BufferMetadata meta, \ + const TensorIndex4D tidx_base, \ + const int hashed_layout, \ + const int block_outer_dim) { \ + const int block_inner_dim = get_packed_dim(hashed_layout); \ + \ + /* Compute base buffer index once and use strides for iteration */ \ + const uint base_idx = \ + tensor4d_idx_to_buf_idx(meta, tidx_base, hashed_layout); \ + const uint outer_stride = stride_at(meta, block_outer_dim); \ + /* Inner stride is 1 since packed_dim == block_inner_dim */ \ + \ + /* Pre-compute bounds for efficient checking */ \ + const uint outer_size = size_at(meta, block_outer_dim); \ + const uint inner_size = size_at(meta, block_inner_dim); \ + const int base_outer_idx = tidx_base.data[block_outer_dim]; \ + const int base_inner_idx = tidx_base.data[block_inner_dim]; \ + \ + mat4 block; \ + [[unroll]] for (int block_y = 0; block_y < 4; ++block_y) { \ + if (base_outer_idx + block_y < int(outer_size)) { \ + const uint row_idx = base_idx + block_y * outer_stride; \ + [[unroll]] for (int block_x = 0; block_x < 4; ++block_x) { \ + if (base_inner_idx + block_x < int(inner_size)) { \ + block[block_y][block_x] = float(buffer_name[row_idx + block_x]); \ + } else { \ + block[block_y][block_x] = 0.0; \ + } \ + } \ + } else { \ + block[block_y] = vec4(0.0); \ + } \ + } \ + return block; \ + } + +// +// Texture load functions +// + +#define define_load_texture_fns(texture_name) \ + \ + mat4 load_fp_block_from_##texture_name( \ + const TextureMetadata meta, \ + const TensorIndex4D tidx_base, \ + const int hashed_layout, \ + const int block_outer_dim) { \ + /* Convert tensor index to texture position */ \ + /* Use tensor4d_idx_to_texel_pos_simple to properly map the packed dim */ \ + ivec3 tex_pos = tensor4d_idx_to_texel_pos_simple(meta, tidx_base); \ + const int tex_outer_dim = mod_4(block_outer_dim); \ + const int outer_size = meta.sizes[block_outer_dim]; \ + const int base_outer_idx = tidx_base.data[block_outer_dim]; \ + \ + mat4 block; \ + [[unroll]] for (int block_y = 0; block_y < 4; ++block_y) { \ + if (base_outer_idx + block_y < outer_size) { \ + block[block_y] = vec4(texelFetch(texture_name, tex_pos, 0)); \ + } else { \ + block[block_y] = vec4(0.0); \ + } \ + tex_pos[tex_outer_dim]++; \ + } \ + return block; \ + } + +#endif // BLOCK_LOAD_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/block_store.glslh b/backends/vulkan/runtime/graph/ops/glsl/block_store.glslh new file mode 100644 index 00000000000..66e9ab9fa2b --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/block_store.glslh @@ -0,0 +1,97 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/* + * Macros to generate block store functions for buffers and textures. + * + * Buffer usage: + * define_store_buffer_fns(t_outp, T) + * + * Texture usage: + * define_store_texture_fns(t_outp, VEC4_T) + * + * Both generate functions with the same signature: + * - store_fp_block_to_t_outp(meta, tidx_base, layout, block_outer_dim, block) + * + * The block_inner_dim is derived from the packed_dim of the hashed_layout. + * + * Parameters: + * buffer_name/texture_name: The name of the output buffer/texture (e.g., t_outp) + * scalar_type/vec4_type: The type for casting (e.g., float/vec4) + */ + +#ifndef BLOCK_STORE_GLSLH +#define BLOCK_STORE_GLSLH + +// +// Buffer store functions +// + +#define define_store_buffer_fns(buffer_name, scalar_type) \ + \ + void store_fp_block_to_##buffer_name( \ + const BufferMetadata meta, \ + const TensorIndex4D tidx_base, \ + const int hashed_layout, \ + const int block_outer_dim, \ + const mat4 block) { \ + const int block_inner_dim = get_packed_dim(hashed_layout); \ + \ + /* Compute base buffer index once and use strides for iteration */ \ + const uint base_idx = \ + tensor4d_idx_to_buf_idx(meta, tidx_base, hashed_layout); \ + const uint outer_stride = stride_at(meta, block_outer_dim); \ + /* Inner stride is 1 since packed_dim == block_inner_dim */ \ + \ + /* Pre-compute bounds for efficient checking */ \ + const uint outer_size = size_at(meta, block_outer_dim); \ + const uint inner_size = size_at(meta, block_inner_dim); \ + const int base_outer_idx = tidx_base.data[block_outer_dim]; \ + const int base_inner_idx = tidx_base.data[block_inner_dim]; \ + \ + [[unroll]] for (int block_y = 0; block_y < 4; ++block_y) { \ + if (base_outer_idx + block_y < int(outer_size)) { \ + const uint row_idx = base_idx + block_y * outer_stride; \ + [[unroll]] for (int block_x = 0; block_x < 4; ++block_x) { \ + if (base_inner_idx + block_x < int(inner_size)) { \ + buffer_name[row_idx + block_x] = \ + scalar_type(block[block_y][block_x]); \ + } \ + } \ + } \ + } \ + } + +// +// Texture store functions +// + +#define define_store_texture_fns(texture_name, vec4_type) \ + \ + void store_fp_block_to_##texture_name( \ + const TextureMetadata meta, \ + const TensorIndex4D tidx_base, \ + const int hashed_layout, \ + const int block_outer_dim, \ + const mat4 block) { \ + /* Convert tensor index to texture position */ \ + /* Use tensor4d_idx_to_texel_pos_simple to properly map the packed dim */ \ + ivec3 tex_pos = tensor4d_idx_to_texel_pos_simple(meta, tidx_base); \ + const int tex_outer_dim = mod_4(block_outer_dim); \ + const int outer_size = meta.sizes[block_outer_dim]; \ + const int base_outer_idx = tidx_base.data[block_outer_dim]; \ + \ + [[unroll]] for (int block_y = 0; block_y < 4; ++block_y) { \ + if (base_outer_idx + block_y < outer_size) { \ + imageStore(texture_name, tex_pos, vec4_type(block[block_y])); \ + } \ + tex_pos[tex_outer_dim]++; \ + } \ + } + +#endif // BLOCK_STORE_GLSLH diff --git a/backends/vulkan/runtime/graph/ops/glsl/common.glslh b/backends/vulkan/runtime/graph/ops/glsl/common.glslh index 9ade64910f2..3b2010c7963 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/common.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/common.glslh @@ -25,6 +25,8 @@ #define div_up_4(x) (((x) + 3) >> 2) #define div_up_8(x) (((x) + 7) >> 3) +#define div_up(x, y) (((x) + (y) - 1) / (y)) + #define align_up_2(x) ((x + 1) & -2) #define align_up_4(x) ((x + 3) & -4) #define align_up_8(x) ((x + 7) & -8) @@ -33,6 +35,7 @@ #define mod_4(x) ((x) & 3) #define mod_8(x) ((x) & 7) + int sign_extend_8bit(const int val) { if ((val & 0x80) != 0) { return val | (~0xFF); diff --git a/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh b/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh index deb6f0a9e30..24f050b694e 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh +++ b/backends/vulkan/runtime/graph/ops/glsl/indexing.glslh @@ -46,6 +46,7 @@ int extract_4b(const int packed, const int pos) { return (packed >> (pos * 4)) & 0xF; } +#define extract_buffer_packed_dim(layout) int(layout & 0xF) // Corresponds to dim_order[:4] = [0, 1, 2, 3] #define CONTIGUOUS_BUFFER_LAYOUT_ID 0x3210 @@ -69,6 +70,13 @@ bool is_channels_last(const int hashed_layout) { return layout_id(hashed_layout) == CHANNELS_LAST_BUFFER_LAYOUT_ID; } +// Extract packed dim info from hashed_layout (bits 16-31) +// These match the format created by create_hashed_layout() in Tensor.cpp +#define get_packed_dim(layout) (((layout) >> 16) & 0xF) +#define get_outer_packed_dim(layout) (((layout) >> 20) & 0xF) +#define get_packed_dim_block_size(layout) (((layout) >> 24) & 0xF) +#define get_outer_packed_dim_block_size(layout) (((layout) >> 28) & 0xF) + // // BufferMetadata // @@ -185,6 +193,11 @@ uint x(const TensorIndex tidx) { return tidx.data[0][0]; } +bool out_of_bounds(const TensorIndex tidx, const BufferMetadata meta) { + return any(greaterThanEqual(tidx.data[0], meta.sizes[0])) || + any(greaterThanEqual(tidx.data[1], meta.sizes[1])); +} + // // TensorIndex4D (useful for texture backed tensors) // @@ -406,6 +419,109 @@ uint tensor4d_idx_to_linear_idx( return lin_idx; } +// +// Block-packed tensor indexing +// + +/* + * Get the number of elements per block from hashed_layout. + */ +int get_block_numel(const int hashed_layout) { + const int inner_block_size = get_packed_dim_block_size(hashed_layout); + const int outer_block_size = get_outer_packed_dim_block_size(hashed_layout); + return inner_block_size * outer_block_size; +} + +/* + * Compute the intra-block index (position within a block). + * + * Within a block, elements are stored with inner dimension varying fastest: + * intra_block_idx = outer_offset * inner_block_size + inner_offset + * + * Parameters: + * tidx: TensorIndex4D with logical tensor coordinates + * hashed_layout: Packed layout info + * + * Returns: Intra-block index (0 to block_numel-1) + */ +int tensor4d_idx_to_intra_block_idx( + const TensorIndex4D tidx, + const int hashed_layout) { + const int inner_dim = get_packed_dim(hashed_layout); + const int outer_dim = get_outer_packed_dim(hashed_layout); + const int inner_block_size = get_packed_dim_block_size(hashed_layout); + const int outer_block_size = get_outer_packed_dim_block_size(hashed_layout); + + const int inner_offset = tidx.data[inner_dim] % inner_block_size; + const int outer_offset = tidx.data[outer_dim] % outer_block_size; + + return outer_offset * inner_block_size + inner_offset; +} + +/* + * Convert a tensor index to a block-space linear index. + * + * The tensor index is converted to block-space coordinates by dividing + * the packed dimensions by their block sizes, then the linear index is + * computed using the block-space strides from BufferMetadata. + * + * Parameters: + * meta: BufferMetadata with block-space strides + * tidx: TensorIndex4D with logical tensor coordinates + * hashed_layout: Packed layout info + * + * Returns: Linear index in block space + */ +int tensor4d_idx_to_block_idx( + const BufferMetadata meta, + TensorIndex4D tidx, + const int hashed_layout) { + // Extract packed dim info + const int inner_dim = get_packed_dim(hashed_layout); + const int outer_dim = get_outer_packed_dim(hashed_layout); + const int inner_block_size = get_packed_dim_block_size(hashed_layout); + const int outer_block_size = get_outer_packed_dim_block_size(hashed_layout); + + // Convert to block-space coordinates + tidx.data[inner_dim] = tidx.data[inner_dim] / inner_block_size; + tidx.data[outer_dim] = tidx.data[outer_dim] / outer_block_size; + + // Compute block-space linear index + int block_idx = 0; + [[unroll]] for (int d = 0; d < 4; ++d) { + block_idx += int(meta.strides[0][d]) * tidx.data[d]; + } + return block_idx; +} + +/* + * Convert a tensor index to a linear buffer index for block-packed layouts. + * + * For block-packed tensors: + * - Elements are grouped into blocks of (inner_block_size × outer_block_size) + * - Strides in BufferMetadata are in "block space" + * - Final index = block_index × block_numel + intra_block_index + * + * Parameters: + * meta: BufferMetadata containing sizes and block-space strides + * tidx: TensorIndex4D with logical tensor coordinates + * hashed_layout: Packed layout info from create_hashed_layout() (see + * indexing.glslh) + * + * Returns: Linear buffer index for the element + */ +int tensor4d_idx_to_buf_idx( + const BufferMetadata meta, + const TensorIndex4D tidx, + const int hashed_layout) { + const int block_idx = tensor4d_idx_to_block_idx(meta, tidx, hashed_layout); + const int intra_block_idx = + tensor4d_idx_to_intra_block_idx(tidx, hashed_layout); + const int block_numel = get_block_numel(hashed_layout); + + return block_idx * block_numel + intra_block_idx; +} + // // Debug utilities // diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_dequantize.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_dequantize.glsl new file mode 100644 index 00000000000..6989dc2d87d --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_dequantize.glsl @@ -0,0 +1,120 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +${define_required_extensions(OUTPUT_STORAGE, DTYPE)} + +#define PRECISION ${PRECISION} +#define T ${buffer_scalar_type(DTYPE)} +$if OUTPUT_STORAGE == "texture3d": + #define VEC4_T ${texel_load_type(DTYPE, "texture3d")} + +$if OUTPUT_STORAGE == "buffer": + ${define_active_storage_type("buffer")} +$else: + ${define_active_storage_type("texture3d")} + +layout(std430) buffer; + +#include "indexing.glslh" + +// Output: dequantized floating point values (buffer or texture) +${layout_declare_tensor(B, "w", "t_outp", DTYPE, OUTPUT_STORAGE)} +// Input buffer: quantized int32 values (each int32 contains 4 packed int8s) +${layout_declare_tensor(B, "r", "t_inp", "int", "buffer")} + +// Metadata for output tensor (floating point) - buffer or texture +$if OUTPUT_STORAGE == "buffer": + ${layout_declare_ubo(B, "BufferMetadata", "outp")} +$else: + ${layout_declare_ubo(B, "TextureMetadata", "outp")} +// Metadata for input tensor (quantized int8x4) - always buffer +${layout_declare_ubo(B, "BufferMetadata", "inp")} + +layout(push_constant) uniform restrict Block { + float scale; + int zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "outp_block_config", "0")} +${layout_declare_spec_const(C, "int", "inp_block_config", "0")} + +#include "block_indexing.glslh" +#include "block_int8x4_load.glslh" +#include "block_store.glslh" + +// Generate loading functions for t_inp buffer +define_load_int8x4_buffer_fns(t_inp) +// Generate storing functions for t_outp +$if OUTPUT_STORAGE == "buffer": + define_store_buffer_fns(t_outp, T) +$else: + define_store_texture_fns(t_outp, VEC4_T) + +mat4 dequantize_int8x4_block( + const ivec4 block, const float scale, const int zp) { + mat4 result; + [[unroll]] for (int i = 0; i < 4; ++i) { + // Unpack 4 int8 values from packed int32 + int packed = block[i]; + ivec4 unpacked = ivec4( + (packed >> 0) & 0xFF, + (packed >> 8) & 0xFF, + (packed >> 16) & 0xFF, + (packed >> 24) & 0xFF); + // Sign extend from 8-bit + unpacked = (unpacked ^ 0x80) - 0x80; + // Dequantize: (q - zp) * scale + result[i] = vec4(unpacked - zp) * scale; + } + return result; +} + +void main() { + TensorIndex4D tidx; + +#ifdef USING_BUFFER + // Buffer storage: use linear dispatch + const uint contig_block_idx = gl_GlobalInvocationID.x; + tidx = contiguous_block_idx_to_tensor4d_idx_with_block_config( + inp, contig_block_idx, inp_block_config); +#else + // Texture storage: use 3D extents dispatch + const uvec3 thread_idx = gl_GlobalInvocationID; + tidx = block_idx_3d_to_tensor4d_idx_with_block_config( + inp, thread_idx, inp_block_config); +#endif + + if (out_of_bounds(tidx, inp)) { + return; + } + + // Load int8 block from input using the thread's block index + const int inp_block_outer_dim = get_block_outer_dim(inp_block_config); + ivec4 int8_block = load_int8x4_block_from_t_inp( + inp, tidx, inp_layout, inp_block_outer_dim); + + // If input and output have different block configs (different packed dims), + // transpose the block to match output's layout + if (inp_block_config != outp_block_config) { + int8_block = transpose_int8x4_block(int8_block); + } + + // Dequantize the int8 block to float values + mat4 fp_block = dequantize_int8x4_block(int8_block, scale, zp); + + // Store dequantized values to output buffer using output's block config + const int outp_block_outer_dim = get_block_outer_dim(outp_block_config); + store_fp_block_to_t_outp( + outp, tidx, outp_layout, outp_block_outer_dim, fp_block); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_dequantize.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_dequantize.yaml new file mode 100644 index 00000000000..d3a531d6385 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_dequantize.yaml @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +q8ta_dequantize: + parameter_names_with_default_values: + DTYPE: float + OUTPUT_STORAGE: buffer + generate_variant_forall: + OUTPUT_STORAGE: + - VALUE: buffer + - VALUE: texture3d + DTYPE: + - VALUE: float + shader_variants: + - NAME: q8ta_dequantize diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_quantize.glsl b/backends/vulkan/runtime/graph/ops/glsl/q8ta_quantize.glsl new file mode 100644 index 00000000000..c0c2a3a914a --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_quantize.glsl @@ -0,0 +1,118 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#version 450 core + +${define_required_extensions(INPUT_STORAGE, DTYPE)} + +#define PRECISION ${PRECISION} +#define T ${buffer_scalar_type(DTYPE)} +$if INPUT_STORAGE == "texture3d": + #define VEC4_T ${texel_load_type(DTYPE, "texture3d")} + +$if INPUT_STORAGE == "buffer": + ${define_active_storage_type("buffer")} +$else: + ${define_active_storage_type("texture3d")} + +layout(std430) buffer; + +#include "indexing.glslh" + +// Output buffer: quantized int32 values (each int32 contains 4 packed int8s) +${layout_declare_tensor(B, "w", "t_outp", "int", "buffer")} +// Input: floating point values (buffer or texture) +${layout_declare_tensor(B, "r", "t_inp", DTYPE, INPUT_STORAGE)} + +// Metadata for output tensor (quantized int8x4) - always buffer +${layout_declare_ubo(B, "BufferMetadata", "outp")} +// Metadata for input tensor (floating point) - buffer or texture +$if INPUT_STORAGE == "buffer": + ${layout_declare_ubo(B, "BufferMetadata", "inp")} +$else: + ${layout_declare_ubo(B, "TextureMetadata", "inp")} + +layout(push_constant) uniform restrict Block { + float inv_scale; + int zp; +}; + +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; + +${layout_declare_spec_const(C, "int", "inp_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "outp_layout", "CONTIG_LAYOUT_INT")} +${layout_declare_spec_const(C, "int", "inp_block_config", "0")} +${layout_declare_spec_const(C, "int", "outp_block_config", "0")} + +#include "block_indexing.glslh" +#include "block_load.glslh" +$if INPUT_STORAGE == "buffer": + // Generate loading functions for t_inp buffer + define_load_buffer_fns(t_inp) +$else: + // Generate loading functions for t_inp texture + define_load_texture_fns(t_inp) +#include "block_int8x4_store.glslh" + +// Generate storing functions for t_outp buffer +define_store_int8x4_buffer_fns(t_outp) + +ivec4 quantize_fp_block( + const mat4 block, const float inv_scale, const int zp) { + ivec4 result; + [[unroll]] for (int i = 0; i < 4; ++i) { + // Quantize: round(val * inv_scale) + zp, clamped to [-128, 127] + ivec4 quantized = ivec4(round(block[i] * inv_scale)) + zp; + quantized = clamp(quantized, -128, 127); + // Pack 4 int8 values into one int32 + result[i] = ((quantized[0] & 0xFF) << 0) | + ((quantized[1] & 0xFF) << 8) | + ((quantized[2] & 0xFF) << 16) | + ((quantized[3] & 0xFF) << 24); + } + return result; +} + +void main() { + TensorIndex4D tidx; + +#ifdef USING_BUFFER + // Buffer storage: use linear dispatch + const uint contig_block_idx = gl_GlobalInvocationID.x; + tidx = contiguous_block_idx_to_tensor4d_idx_with_block_config( + inp, contig_block_idx, inp_block_config); +#else + // Texture storage: use 3D extents dispatch + const uvec3 thread_idx = gl_GlobalInvocationID; + tidx = block_idx_3d_to_tensor4d_idx_with_block_config( + inp, thread_idx, inp_block_config); +#endif + + if (out_of_bounds(tidx, inp)) { + return; + } + + // Load FP block from input using the thread's block index + const int inp_block_outer_dim = get_block_outer_dim(inp_block_config); + mat4 fp_block = load_fp_block_from_t_inp( + inp, tidx, inp_layout, inp_block_outer_dim); + + // If input and output have different block configs (different packed dims), + // transpose the block to match output's layout + if (inp_block_config != outp_block_config) { + fp_block = transpose(fp_block); + } + + // Quantize the float block to int8 values + ivec4 int8_block = quantize_fp_block(fp_block, inv_scale, zp); + + // Store quantized values to output buffer using output's block config + const int outp_block_outer_dim = get_block_outer_dim(outp_block_config); + store_int8x4_block_to_t_outp( + outp, tidx, outp_layout, outp_block_outer_dim, int8_block); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/q8ta_quantize.yaml b/backends/vulkan/runtime/graph/ops/glsl/q8ta_quantize.yaml new file mode 100644 index 00000000000..b3124fb6d75 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/glsl/q8ta_quantize.yaml @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +q8ta_quantize: + parameter_names_with_default_values: + DTYPE: float + INPUT_STORAGE: buffer + generate_variant_forall: + INPUT_STORAGE: + - VALUE: buffer + - VALUE: texture3d + DTYPE: + - VALUE: float + shader_variants: + - NAME: q8ta_quantize diff --git a/backends/vulkan/runtime/graph/ops/impl/Common.cpp b/backends/vulkan/runtime/graph/ops/impl/Common.cpp index 71690ffc604..0286889de5c 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Common.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Common.cpp @@ -6,10 +6,93 @@ * LICENSE file in the root directory of this source tree. */ +#include #include namespace vkcompute { +// +// BlockConfig implementation +// + +BlockConfig::BlockConfig( + int32_t inner, + int32_t inner_block_size, + int32_t outer, + int32_t outer_block_size, + bool transposed) + : inner_dim(inner), + inner_dim_block_size(inner_block_size), + outer_dim(outer), + outer_dim_block_size(outer_block_size), + block_transposed(transposed), + block_dim_order{0, 0, 0, 0} { + // Block dims must be different + VK_CHECK_COND(outer_dim != inner_dim); + + // Find the two lowest dim indices that are not inner_dim or outer_dim + int32_t first_nonblock_dim = -1; + int32_t second_nonblock_dim = -1; + int32_t other_idx = 0; + for (int32_t d = 0; other_idx < 2; ++d) { + if (d != inner_dim && d != outer_dim) { + if (other_idx == 0) { + first_nonblock_dim = d; + } else { + second_nonblock_dim = d; + } + ++other_idx; + } + } + + // Set block_dim_order based on block_transposed + if (block_transposed) { + // Transposed: {outer_dim, inner_dim, first_nonblock_dim, + // second_nonblock_dim} + block_dim_order[0] = outer_dim; + block_dim_order[1] = inner_dim; + } else { + // Normal: {inner_dim, outer_dim, first_nonblock_dim, second_nonblock_dim} + block_dim_order[0] = inner_dim; + block_dim_order[1] = outer_dim; + } + block_dim_order[2] = first_nonblock_dim; + block_dim_order[3] = second_nonblock_dim; + + // Validate all dims are in valid range [0, 3] + for (int i = 0; i < 4; ++i) { + VK_CHECK_COND(block_dim_order[i] >= 0 && block_dim_order[i] < 4); + } +} + +int32_t BlockConfig::as_packed_int() const { + int32_t packed = 0; + // Pack block_dim_order in bits 0-15 (matches hashed layout format) + packed |= (block_dim_order[0] & 0xF); // bits 0-3 + packed |= (block_dim_order[1] & 0xF) << 4; // bits 4-7 + packed |= (block_dim_order[2] & 0xF) << 8; // bits 8-11 + packed |= (block_dim_order[3] & 0xF) << 12; // bits 12-15 + // Pack packed_dim_info in bits 16-31 (matches hashed layout format) + packed |= (inner_dim & 0xF) << 16; // bits 16-19 + packed |= (outer_dim & 0xF) << 20; // bits 20-23 + packed |= (inner_dim_block_size & 0xF) << 24; // bits 24-27 + packed |= (outer_dim_block_size & 0xF) << 28; // bits 28-31 + + return packed; +} + +int32_t BlockConfig::inner_dim_from_packed_int(int32_t packed_int) { + return (packed_int >> 16) & 0xF; // bits 16-19 +} + +int32_t BlockConfig::outer_dim_from_packed_int(int32_t packed_int) { + return (packed_int >> 20) & 0xF; // bits 20-23 +} + +// +// Default workgroup size functions +// + utils::uvec3 default_pick_global_wg_size( ComputeGraph* graph, const vkapi::ShaderInfo& shader, @@ -79,4 +162,202 @@ utils::uvec3 pick_wc_square_wg_size( return {16u, 1u, 4u}; } +BlockConfig create_block_config_from_io_packed_dims( + ComputeGraph& graph, + const ValueRef output, + const ValueRef input) { + const int32_t block_inner_dim = graph.packed_dim_of(output); + int32_t block_outer_dim = graph.packed_dim_of(input); + + // If inner and outer dims are the same, pick a different outer dim + if (block_outer_dim == block_inner_dim) { + if (block_inner_dim == 0) { + block_outer_dim = 1; + } else { + block_outer_dim = 0; + } + } + + // Create a BlockConfig with block sizes of 4 for both dimensions + return BlockConfig{block_inner_dim, 4, block_outer_dim, 4}; +} + +BlockConfig create_block_config_for_tensor( + ComputeGraph& graph, + const ValueRef tensor) { + const int32_t packed_dim = graph.packed_dim_of(tensor); + + // Pick an outer dimension that differs from the packed dimension + const int32_t outer_dim = (packed_dim == 0) ? 1 : 0; + + // Create a BlockConfig with block sizes of 4 for both dimensions + return BlockConfig{packed_dim, 4, outer_dim, 4}; +} + +BlockConfig create_block_config_from_other( + ComputeGraph& graph, + const ValueRef tensor, + const BlockConfig& other) { + const int32_t packed_dim = graph.packed_dim_of(tensor); + + // If tensor's packed dim matches other's inner dim, use same config + if (packed_dim == other.inner_dim) { + return other; + } + + // Otherwise, transpose: swap inner and outer dimensions + return BlockConfig{ + other.outer_dim, + other.outer_dim_block_size, + other.inner_dim, + other.inner_dim_block_size}; +} + +utils::uvec3 pick_linear_global_wg_with_block_config( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& extra_args) { + (void)shader; + + const ValueRef output = args.at(0).refs.at(0); + // extra_args contains the packed block config directly as a ValueRef + // (int32_t) + const int32_t packed_block_config = static_cast(extra_args.at(0)); + + // Extract block configuration from packed integer + const int32_t inner_dim = + BlockConfig::inner_dim_from_packed_int(packed_block_config); + const int32_t outer_dim = + BlockConfig::outer_dim_from_packed_int(packed_block_config); + + const std::vector& sizes = graph->sizes_of(output); + const size_t ndim = sizes.size(); + + // Compute number of blocks along inner and outer dimensions + const int64_t inner_size = sizes[ndim - 1 - inner_dim]; + const int64_t outer_size = sizes[ndim - 1 - outer_dim]; + + const uint32_t num_inner_blocks = + utils::safe_downcast(utils::div_up(inner_size, int64_t(4))); + const uint32_t num_outer_blocks = + utils::safe_downcast(utils::div_up(outer_size, int64_t(4))); + + // Compute number of planes (product of dimensions not in the block) + uint32_t num_planes = 1; + for (size_t i = 0; i < ndim; ++i) { + const int32_t whcn_dim = ndim - 1 - i; + if (whcn_dim != inner_dim && whcn_dim != outer_dim) { + num_planes *= utils::safe_downcast(sizes[i]); + } + } + + // Return linear workgroup size: {total_blocks, 1u, 1u} + const uint32_t total_blocks = + num_inner_blocks * num_outer_blocks * num_planes; + return {total_blocks, 1u, 1u}; +} + +utils::uvec3 pick_extents_global_wg_with_block_config( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& extra_args) { + (void)shader; + + const ValueRef output = args.at(0).refs.at(0); + // extra_args contains the packed block config directly as a ValueRef + // (int32_t) + const int32_t packed_block_config = static_cast(extra_args.at(0)); + + // Extract block configuration from packed integer + // Note: inner_dim and outer_dim use WHCN order (0=W, 1=H, 2=C, 3=N) + const int32_t inner_dim = + BlockConfig::inner_dim_from_packed_int(packed_block_config); + const int32_t outer_dim = + BlockConfig::outer_dim_from_packed_int(packed_block_config); + + const std::vector& sizes = graph->sizes_of(output); + + // C++ sizes are in NCHW order: sizes[0]=N, sizes[1]=C, sizes[2]=H, sizes[3]=W + // Access dimensions from the end for tensors with fewer than 4 dims + const int64_t W = utils::val_at(-1, sizes); + const int64_t H = utils::val_at(-2, sizes); + const int64_t C = utils::val_at(-3, sizes); + const int64_t N = utils::val_at(-4, sizes); + + // Dispatch structure: {x_threads, y_threads, z_threads} + // - x corresponds to W dimension + // - y corresponds to H dimension + // - z corresponds to C * N (combined) + // + // Block dimensions (inner_dim and outer_dim) are divided by 4, + // non-block dimensions are not divided. + + uint32_t x_threads, y_threads; + int64_t C_for_z; + + // X dimension (W, WHCN dim 0) + if (inner_dim == 0 || outer_dim == 0) { + x_threads = utils::safe_downcast(utils::div_up(W, int64_t(4))); + } else { + x_threads = utils::safe_downcast(W); + } + + // Y dimension (H, WHCN dim 1) + if (inner_dim == 1 || outer_dim == 1) { + y_threads = utils::safe_downcast(utils::div_up(H, int64_t(4))); + } else { + y_threads = utils::safe_downcast(H); + } + + // Z dimension: C * N where C is blocked if it's part of the block + if (inner_dim == 2 || outer_dim == 2) { + C_for_z = utils::div_up(C, int64_t(4)); + } else { + C_for_z = C; + } + const uint32_t z_threads = utils::safe_downcast(C_for_z * N); + + return {x_threads, y_threads, z_threads}; +} + +utils::uvec3 pick_square_local_wg_with_block_config( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& extra_args) { + (void)graph; + (void)shader; + (void)args; + + // Detect linear dispatch pattern: global_wg = {total_blocks, 1, 1} + if (global_workgroup_size[1u] == 1u && global_workgroup_size[2u] == 1u) { + return {64u, 1u, 1u}; + } + + // Extents dispatch: use 8x8 square on inner_dim and outer_dim axes + // extra_args contains the packed block config as a ValueRef (int32_t) + const int32_t packed_block_config = static_cast(extra_args.at(0)); + + // Extract block configuration from packed integer + // inner_dim and outer_dim use WHCN order (0=W, 1=H, 2=C, 3=N) + const int32_t inner_dim = + BlockConfig::inner_dim_from_packed_int(packed_block_config); + const int32_t outer_dim = + BlockConfig::outer_dim_from_packed_int(packed_block_config); + + // Build local workgroup size: + // - x corresponds to W (WHCN dim 0) + // - y corresponds to H (WHCN dim 1) + // - z corresponds to C*N (WHCN dim 2 for C) + // Set axes in the block (inner_dim, outer_dim) to 8, others to 1 + uint32_t local_x = (inner_dim == 0 || outer_dim == 0) ? 8u : 1u; + uint32_t local_y = (inner_dim == 1 || outer_dim == 1) ? 8u : 1u; + uint32_t local_z = (inner_dim == 2 || outer_dim == 2) ? 8u : 1u; + + return {local_x, local_y, local_z}; +} + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Common.h b/backends/vulkan/runtime/graph/ops/impl/Common.h index b412f737c13..84cacc8e4f7 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Common.h +++ b/backends/vulkan/runtime/graph/ops/impl/Common.h @@ -13,6 +13,73 @@ namespace vkcompute { +/** + * BlockConfig describes how a tensor is partitioned into blocks for the purpose + * of thread mapping in GPU compute shaders. Each thread processes one block + * of elements. + * + * This is distinct from PackedDimInfo in Tensor.h which describes memory + * layout. BlockConfig is used solely for operator implementations to define 4x4 + * block partitioning schemes. + * + * The block configuration has two dimensions: + * - inner_dim: The dimension where 4 consecutive elements are processed + * together within a single thread + * - outer_dim: A second dimension where 4 elements are grouped, resulting + * in a 4x4 block of 16 elements per thread + */ +struct BlockConfig { + // The inner block dimension (WHCN index: 0=W, 1=H, 2=C, 3=N) + // 4 consecutive elements along this dimension form the inner part of a block + int32_t inner_dim; + // Block size along the inner dimension (typically 4) + int32_t inner_dim_block_size; + // The outer block dimension (WHCN index: 0=W, 1=H, 2=C, 3=N) + // 4 elements along this dimension form the outer part of a block + int32_t outer_dim; + // Block size along the outer dimension (typically 4) + int32_t outer_dim_block_size; + // Whether the block is transposed (swaps stride ordering of inner/outer dim) + bool block_transposed; + // Dimension order for the block: + // - If block_transposed = false: {inner_dim, outer_dim, first_nonblock_dim, + // second_nonblock_dim} + // - If block_transposed = true: {outer_dim, inner_dim, first_nonblock_dim, + // second_nonblock_dim} + int32_t block_dim_order[4]; + + BlockConfig( + int32_t inner, + int32_t inner_block_size, + int32_t outer, + int32_t outer_block_size, + bool transposed = false); + + /** + * Returns a packed int32_t encoding the block configuration. The structure + * matches the hashed layout int format used in shaders: + * bits 0- 3: block_dim_order[0] + * bits 4- 7: block_dim_order[1] + * bits 8-11: block_dim_order[2] + * bits 12-15: block_dim_order[3] + * bits 16-19: inner_dim + * bits 20-23: outer_dim + * bits 24-27: inner_dim_block_size + * bits 28-31: outer_dim_block_size + */ + int32_t as_packed_int() const; + + /** + * Extracts inner_dim from a packed int32_t representation. + */ + static int32_t inner_dim_from_packed_int(int32_t packed_int); + + /** + * Extracts outer_dim from a packed int32_t representation. + */ + static int32_t outer_dim_from_packed_int(int32_t packed_int); +}; + /** * Creates a global workgroup size based on the first output tensor in the args. * This is a utility function that extracts the output tensor from @@ -61,4 +128,133 @@ utils::uvec3 pick_wc_square_wg_size( const std::vector& args, const std::vector& resize_args); +/** + * Creates a BlockConfig based on the packed dimensions of an output and input + * tensor pair. This is useful for operations like dequantize where the block + * configuration depends on both tensors. + * + * The inner dimension is determined by the output tensor's packed dimension, + * and the outer dimension is determined by the input tensor's packed dimension. + * If they are the same, the outer dimension is adjusted to avoid conflict. + * + * @param graph The compute graph + * @param output The output tensor reference + * @param input The input tensor reference + * @return A BlockConfig configured for block-based operations + */ +BlockConfig create_block_config_from_io_packed_dims( + ComputeGraph& graph, + const ValueRef output, + const ValueRef input); + +/** + * Creates a BlockConfig based on the packed dimension of a single tensor. + * This is useful when you need separate block configs for input and output + * tensors. + * + * The inner dimension is determined by the tensor's packed dimension. + * The outer dimension is set to an adjacent dimension that differs from + * the packed dimension. + * + * @param graph The compute graph + * @param tensor The tensor reference + * @return A BlockConfig configured for block-based operations + */ +BlockConfig create_block_config_for_tensor( + ComputeGraph& graph, + const ValueRef tensor); + +/** + * Creates a BlockConfig for a tensor based on another block config, ensuring + * the inner dimension matches the tensor's packed dimension. + * + * This is useful when you need block configs for both input and output tensors + * that share the same block axes but may need to be transposed if the tensors + * have different packed dimensions. + * + * If the tensor's packed dim matches the other config's inner dim, returns + * the same config. Otherwise, returns a transposed config (inner/outer + * swapped). + * + * @param graph The compute graph + * @param tensor The tensor to create a block config for + * @param other The reference block config to base the new config on + * @return A BlockConfig with inner_dim = tensor's packed_dim + */ +BlockConfig create_block_config_from_other( + ComputeGraph& graph, + const ValueRef tensor, + const BlockConfig& other); + +/** + * Picks a global workgroup size for block-based dispatching using a linear + * (1D flattened) dispatch pattern. This is optimized for buffer storage. + * + * This function expects: + * - args.at(0).refs.at(0): Output tensor reference + * - extra_args.at(0): Packed int32_t block configuration cast to ValueRef + * (created via static_cast(BlockConfig::as_packed_int())) + * + * The global workgroup size is computed as: + * - x = total_blocks = num_inner_blocks * num_outer_blocks * num_planes + * - y = 1 + * - z = 1 + * + * @return Global workgroup size as {total_blocks, 1, 1} + */ +utils::uvec3 pick_linear_global_wg_with_block_config( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& extra_args); + +/** + * Picks a global workgroup size for block-based dispatching using a 3D + * extents-style dispatch pattern. This is optimized for texture storage. + * + * This function expects: + * - args.at(0).refs.at(0): Output tensor reference + * - extra_args.at(0): Packed int32_t block configuration cast to ValueRef + * (created via static_cast(BlockConfig::as_packed_int())) + * + * The global workgroup size is computed as a WHCN-based 3D dispatch: + * - x = W threads (divided by 4 if W is inner or outer dim) + * - y = H threads (divided by 4 if H is inner or outer dim) + * - z = C * N threads (C divided by 4 if C is inner or outer dim) + * + * @return Global workgroup size as {x_threads, y_threads, z_threads} + */ +utils::uvec3 pick_extents_global_wg_with_block_config( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const std::vector& args, + const std::vector& extra_args); + +/** + * Picks a local workgroup size for block-based dispatching that is optimized + * for the dispatch pattern in use. + * + * This function expects: + * - extra_args.at(0): Packed int32_t block configuration cast to ValueRef + * (created via static_cast(BlockConfig::as_packed_int())) + * + * For linear dispatch (buffer storage, global_wg = {total_blocks, 1, 1}): + * - Returns {64, 1, 1} + * + * For extents dispatch (texture storage, global_wg = {x, y, z}): + * - Returns an 8x8 square configuration where: + * - Axes corresponding to inner_dim and outer_dim are set to 8 + * - The remaining axis is set to 1 + * - For example: inner_dim=W, outer_dim=H -> {8, 8, 1} + * inner_dim=W, outer_dim=C -> {8, 1, 8} + * + * @return Local workgroup size optimized for the dispatch pattern + */ +utils::uvec3 pick_square_local_wg_with_block_config( + ComputeGraph* graph, + const vkapi::ShaderInfo& shader, + const utils::uvec3& global_workgroup_size, + const std::vector& args, + const std::vector& extra_args); + } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taQuantizeDequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/Q8taQuantizeDequantize.cpp new file mode 100644 index 00000000000..bca36444725 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taQuantizeDequantize.cpp @@ -0,0 +1,156 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace vkcompute { + +void add_q8ta_quantize_node( + ComputeGraph& graph, + const ValueRef fp_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef packed_int8_output) { + float inv_scale = 1.0f / graph.extract_scalar(input_scale); + int32_t zp = graph.extract_scalar(input_zp); + + // Detect input storage type to select appropriate shader variant + utils::StorageType inp_storage = graph.storage_type_of(fp_input); + + // Build shader name: q8ta_quantize_{buffer|texture3d}_{dtype} + std::string kernel_name = "q8ta_quantize"; + add_storage_type_suffix(kernel_name, inp_storage); + add_dtype_suffix(kernel_name, graph.dtype_of(fp_input)); + + // Pass metadata for both output and input tensors + // Output is always buffer, input can be buffer or texture + vkapi::ParamsBindList param_buffers; + param_buffers.append(graph.buffer_meta_ubo(packed_int8_output)); + param_buffers.append(graph.meta_ubo(fp_input)); + + std::vector push_constants = { + PushConstantDataInfo(&inv_scale, sizeof(inv_scale)), + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + // Create block config for output tensor: inner_dim = output's packed_dim + const BlockConfig outp_block_config = create_block_config_from_io_packed_dims( + graph, packed_int8_output, fp_input); + + // Create block config for input tensor: based on outp_block_config but with + // inner_dim = input's packed_dim. If input and output have different packed + // dims, the block axes are transposed. + const BlockConfig inp_block_config = + create_block_config_from_other(graph, fp_input, outp_block_config); + + // Cast block config to ValueRef for pick_*_global_wg_with_block_config + // Use inp_block_config since shader uses inp_block_config for indexing + const ValueRef block_config_ref = + static_cast(inp_block_config.as_packed_int()); + + // Choose dispatch function based on FP input storage type: + // - Buffer: use linear dispatch (better performance) + // - Texture: use extents-style 3D dispatch (better performance) + auto pick_global_wg_size = (inp_storage == utils::kBuffer) + ? pick_linear_global_wg_with_block_config + : pick_extents_global_wg_with_block_config; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + pick_global_wg_size, + pick_square_local_wg_with_block_config, + // Inputs and Outputs + {{packed_int8_output, vkapi::kWrite}, {fp_input, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {graph.hashed_layout_of(fp_input), + graph.hashed_layout_of(packed_int8_output), + inp_block_config.as_packed_int(), + outp_block_config.as_packed_int()}, + // Resize args + {block_config_ref})); +} + +void add_q8ta_dequantize_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef fp_output) { + float scale = graph.extract_scalar(output_scale); + int32_t zp = graph.extract_scalar(output_zp); + + // Detect output storage type to select appropriate shader variant + utils::StorageType outp_storage = graph.storage_type_of(fp_output); + + // Build shader name: q8ta_dequantize_{buffer|texture3d}_{dtype} + std::string kernel_name = "q8ta_dequantize"; + add_storage_type_suffix(kernel_name, outp_storage); + add_dtype_suffix(kernel_name, graph.dtype_of(fp_output)); + + // Pass metadata for both output and input tensors + // Output can be buffer or texture, input is always buffer + vkapi::ParamsBindList param_buffers; + param_buffers.append(graph.meta_ubo(fp_output)); + param_buffers.append(graph.buffer_meta_ubo(packed_int8_input)); + + std::vector push_constants = { + PushConstantDataInfo(&scale, sizeof(scale)), + PushConstantDataInfo(&zp, sizeof(zp)), + }; + + // Create block config for output tensor: inner_dim = output's packed_dim + const BlockConfig outp_block_config = create_block_config_from_io_packed_dims( + graph, fp_output, packed_int8_input); + + // Create block config for input tensor: based on outp_block_config but with + // inner_dim = input's packed_dim. If input and output have different packed + // dims, the block axes are transposed. + const BlockConfig inp_block_config = create_block_config_from_other( + graph, packed_int8_input, outp_block_config); + + // Cast block config to ValueRef for pick_*_global_wg_with_block_config + // Use inp_block_config since shader uses inp_block_config for indexing + const ValueRef block_config_ref = + static_cast(inp_block_config.as_packed_int()); + + // Choose dispatch function based on FP output storage type: + // - Buffer: use linear dispatch (better performance) + // - Texture: use extents-style 3D dispatch (better performance) + auto pick_global_wg_size = (outp_storage == utils::kBuffer) + ? pick_linear_global_wg_with_block_config + : pick_extents_global_wg_with_block_config; + + graph.execute_nodes().emplace_back(new DynamicDispatchNode( + graph, + VK_KERNEL_FROM_STR(kernel_name), + pick_global_wg_size, + pick_square_local_wg_with_block_config, + // Inputs and Outputs + {{fp_output, vkapi::kWrite}, {packed_int8_input, vkapi::kRead}}, + // Shader params buffers + param_buffers, + // Push Constants + push_constants, + // Specialization Constants + {graph.hashed_layout_of(fp_output), + graph.hashed_layout_of(packed_int8_input), + outp_block_config.as_packed_int(), + inp_block_config.as_packed_int()}, + // Resize args + {block_config_ref})); +} + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/Q8taQuantizeDequantize.h b/backends/vulkan/runtime/graph/ops/impl/Q8taQuantizeDequantize.h new file mode 100644 index 00000000000..e2c6a2c51c3 --- /dev/null +++ b/backends/vulkan/runtime/graph/ops/impl/Q8taQuantizeDequantize.h @@ -0,0 +1,33 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include + +namespace vkcompute { + +// +// Generic Quantize, Dequantize (memory layout agnostic) +// + +void add_q8ta_quantize_node( + ComputeGraph& graph, + const ValueRef fp_input, + const ValueRef input_scale, + const ValueRef input_zp, + const ValueRef packed_int8_output); + +void add_q8ta_dequantize_node( + ComputeGraph& graph, + const ValueRef packed_int8_input, + const ValueRef output_scale, + const ValueRef output_zp, + const ValueRef fp_output); + +} // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.cpp index 8ebbf6dcb99..e02a42f60e1 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizeDequantize.cpp @@ -9,9 +9,12 @@ #include #include +#include #include #include +#include + namespace vkcompute { // @@ -383,11 +386,10 @@ void quantize_per_tensor_impl( const ValueRef int8_output = args[last_arg_idx]; - VK_CHECK_COND( - graph.estimate_memory_layout_of(int8_output) == utils::kPackedInt8_4W4C); + VK_CHECK_COND(graph.dtype_of(int8_output) == vkapi::kInt8x4); - add_quantize_and_pack_4w4c_node( - graph, fp_input, scale, zero_point, int8_output); + // Use unified block-based dispatch for all layouts + add_q8ta_quantize_node(graph, fp_input, scale, zero_point, int8_output); } void dequantize_per_tensor_impl( @@ -409,34 +411,10 @@ void dequantize_per_tensor_impl( const ValueRef fp_output = args[last_arg_idx]; - VK_CHECK_COND( - graph.estimate_memory_layout_of(int8_input) == utils::kPackedInt8_4W4C); - - add_unpack_4w4c_and_dequantize_node( - graph, int8_input, scale, zero_point, fp_output); -} + VK_CHECK_COND(graph.dtype_of(int8_input) == vkapi::kInt8x4); -void qdq8ta_conv2d_input( - ComputeGraph& graph, - const std::vector& args) { - int32_t idx = 0; - const ValueRef fp_input = args.at(idx++); - const ValueRef scale = args.at(idx++); - const ValueRef zero_point = args.at(idx++); - const ValueRef fp_output = args.at(idx++); - - TmpTensor packed_int8_input( - &graph, - graph.sizes_of(fp_input), - vkapi::kInt8x4, - utils::kBuffer, - utils::kPackedInt8_4W4C); - - add_quantize_and_pack_4w4c_node( - graph, fp_input, scale, zero_point, packed_int8_input); - - add_unpack_4w4c_and_dequantize_node( - graph, packed_int8_input, scale, zero_point, fp_output); + // Use unified block-based dispatch for all layouts + add_q8ta_dequantize_node(graph, int8_input, scale, zero_point, fp_output); } REGISTER_OPERATORS { @@ -446,7 +424,6 @@ REGISTER_OPERATORS { VK_REGISTER_OP( quantized_decomposed.dequantize_per_tensor.default, dequantize_per_tensor_impl); - VK_REGISTER_OP(etvk.qdq8ta_conv2d_input.default, qdq8ta_conv2d_input); } } // namespace vkcompute diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp index ff00215efbc..2adb32d8c77 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedConvolution.cpp @@ -9,7 +9,7 @@ #include #include -#include +#include #include #include #include @@ -1579,7 +1579,7 @@ void conv2d_q8ta_q8csw_q8to_test( io_storage_type, utils::kPackedInt8_4W4C); - add_quantize_and_pack_4w4c_node( + add_q8ta_quantize_node( graph, fp_input, input_scale, input_zp, packed_int8_input); std::vector conv2d_args = { @@ -1601,7 +1601,7 @@ void conv2d_q8ta_q8csw_q8to_test( conv2d_q8ta_q8csw_q8to(graph, conv2d_args); - add_unpack_4w4c_and_dequantize_node( + add_q8ta_dequantize_node( graph, packed_int8_output, output_scale, output_zp, fp_output); } diff --git a/backends/vulkan/test/custom_ops/CMakeLists.txt b/backends/vulkan/test/custom_ops/CMakeLists.txt index 6db814815fb..781d69c10fe 100644 --- a/backends/vulkan/test/custom_ops/CMakeLists.txt +++ b/backends/vulkan/test/custom_ops/CMakeLists.txt @@ -97,7 +97,7 @@ if(TARGET vulkan_backend) add_operator_prototype(q8csw_conv2d) add_operator_prototype(q4gsw_linear) add_operator_prototype(choose_qparams_per_row) - add_operator_prototype(qdq8ta_conv2d_activations) + add_operator_prototype(test_q8ta_qdq) add_operator_prototype(q8ta_q8csw_q8to_conv2d) add_operator_prototype(q8ta_q8csw_q8to_conv2d_dw) add_operator_prototype(q8ta_q8ta_q8to_add) diff --git a/backends/vulkan/test/custom_ops/impl/TestQ8taQuantizeDequantize.cpp b/backends/vulkan/test/custom_ops/impl/TestQ8taQuantizeDequantize.cpp new file mode 100644 index 00000000000..24f19b4d309 --- /dev/null +++ b/backends/vulkan/test/custom_ops/impl/TestQ8taQuantizeDequantize.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include +#include + +namespace vkcompute { + +void q_dq_8bit_per_tensor( + ComputeGraph& graph, + const std::vector& args) { + int32_t idx = 0; + const ValueRef fp_input = args.at(idx++); + const ValueRef scale = args.at(idx++); + const ValueRef zero_point = args.at(idx++); + const ValueRef layout_int = args.at(idx++); + const ValueRef impl_selector_str = args.at(idx++); + const ValueRef fp_output = args.at(idx++); + + // Extract the layout parameter and cast to GPUMemoryLayout + int32_t layout_value = graph.extract_scalar(layout_int); + utils::GPUMemoryLayout layout = + static_cast(layout_value); + + // Extract the impl_selector string + std::string impl_selector = graph.extract_string(impl_selector_str); + + // Use legacy 4W4C implementation if requested and layout matches + if (impl_selector == "legacy_4w4c" && layout == utils::kPackedInt8_4W4C) { + TmpTensor packed_int8_input( + &graph, + graph.sizes_of(fp_input), + vkapi::kInt8x4, + utils::kBuffer, + utils::kPackedInt8_4W4C); + + add_quantize_and_pack_4w4c_node( + graph, fp_input, scale, zero_point, packed_int8_input); + + add_unpack_4w4c_and_dequantize_node( + graph, packed_int8_input, scale, zero_point, fp_output); + } else { + // Create temporary tensor with the specified layout + TmpTensor packed_int8_input( + &graph, + graph.sizes_of(fp_input), + vkapi::kInt8x4, + utils::kBuffer, + layout); + + // Use unified block-based dispatch + add_q8ta_quantize_node( + graph, fp_input, scale, zero_point, packed_int8_input); + + add_q8ta_dequantize_node( + graph, packed_int8_input, scale, zero_point, fp_output); + } +} + +REGISTER_OPERATORS { + VK_REGISTER_OP(test_etvk.q_dq_8bit_per_tensor.default, q_dq_8bit_per_tensor); +} + +} // namespace vkcompute diff --git a/backends/vulkan/test/custom_ops/qdq8ta_conv2d_activations.cpp b/backends/vulkan/test/custom_ops/qdq8ta_conv2d_activations.cpp deleted file mode 100644 index b8b33f30623..00000000000 --- a/backends/vulkan/test/custom_ops/qdq8ta_conv2d_activations.cpp +++ /dev/null @@ -1,255 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the BSD-style license found in the -// LICENSE file in the root directory of this source tree. - -#include -#include -#include -#include -#include -#include -#include "utils.h" - -#include - -using namespace executorch::vulkan::prototyping; -using namespace vkcompute; - -static constexpr int64_t kRefDimSizeLimit = 512; - -// QDQ8TA Conv2D configuration struct for 4D tensor quantize-dequantize testing -struct QDQ8TAConv2DConfig { - int64_t batch_size; // N dimension - int64_t in_channels; // C dimension - int64_t height; // H dimension - int64_t width; // W dimension - std::string test_case_name = "placeholder"; - std::string op_name = "qdq8ta_conv2d_input"; -}; - -// Utility function to create a test case from a QDQ8TAConv2DConfig -TestCase create_test_case_from_config( - const QDQ8TAConv2DConfig& config, - utils::StorageType storage_type, - vkapi::ScalarType input_dtype) { - TestCase test_case; - - // Create a descriptive name for the test case - std::string storage_str = - (storage_type == utils::kTexture3D) ? "Texture3D" : "Buffer"; - std::string dtype_str = (input_dtype == vkapi::kFloat) ? "Float" : "Half"; - - std::string test_name = - config.test_case_name + "_" + storage_str + "_" + dtype_str; - test_case.set_name(test_name); - - // Set the operator name for the test case - std::string operator_name = "etvk." + config.op_name + ".default"; - test_case.set_operator_name(operator_name); - - // Input tensor (float) - [N, C, H, W] - std::vector input_size = { - config.batch_size, config.in_channels, config.height, config.width}; - ValueSpec input_tensor( - input_size, - input_dtype, - storage_type, - utils::kChannelsPacked, // Use channels packed for conv2d tensors - DataGenType::RANDOM); - - float scale_val = 0.007112; - ValueSpec scale(scale_val); - - // Generate random zero point within quantization range - int32_t zero_point_val = -2; - ValueSpec zero_point(zero_point_val); - - // Output tensor (float) - same shape as input [N, C, H, W] - ValueSpec output_tensor( - input_size, - input_dtype, - storage_type, - utils::kChannelsPacked, - DataGenType::ZEROS); - - // Add all specs to test case - test_case.add_input_spec(input_tensor); - test_case.add_input_spec(scale); - test_case.add_input_spec(zero_point); - test_case.add_output_spec(output_tensor); - - test_case.set_abs_tolerance(scale_val + 1e-4); - - // Use layout-only filter for this test since quantize/dequantize ARE the - // operations being tested, not overhead - test_case.set_shader_filter(kLayoutOnlyShaderFilter); - - return test_case; -} - -// Generate easy test cases for qdq8ta_conv2d operation (for debugging) -std::vector generate_qdq8ta_conv2d_easy_cases() { - std::vector test_cases; - - // Single simple configuration for debugging - QDQ8TAConv2DConfig config = { - 1, // batch_size - 3, // in_channels - 4, // height - 4, // width - "simple", // test_case_name - }; - - // Test with both storage types - std::vector storage_types = {utils::kTexture3D}; - std::vector float_types = {vkapi::kFloat}; - - // Generate test cases for each combination - for (const auto& storage_type : storage_types) { - for (const auto& input_dtype : float_types) { - test_cases.push_back( - create_test_case_from_config(config, storage_type, input_dtype)); - } - } - - return test_cases; -} - -// Generate test cases for qdq8ta_conv2d operation -std::vector generate_qdq8ta_conv2d_test_cases() { - std::vector test_cases; - - std::vector configs = { - // Small test cases for correctness - {1, 3, 16, 16}, - {1, 8, 32, 32}, - {1, 16, 24, 24}, - {1, 32, 12, 12}, - {1, 1, 64, 64}, - {1, 3, 64, 64}, - {1, 4, 16, 16}, - - // Different tensor sizes - {1, 8, 20, 20}, - {1, 16, 14, 14}, - {1, 8, 28, 28}, - - // Odd tensor sizes - {1, 3, 15, 15}, - {1, 13, 31, 31}, - {1, 17, 23, 23}, - - // Performance test cases (larger tensors) - {1, 64, 128, 128}, - {1, 32, 64, 64}, - {1, 128, 56, 56}, - }; - - // Test with different storage types - std::vector storage_types = {utils::kTexture3D}; - - for (auto config : configs) { - std::string prefix = - (config.batch_size < kRefDimSizeLimit && - config.in_channels < kRefDimSizeLimit && - config.height < kRefDimSizeLimit && config.width < kRefDimSizeLimit) - ? "correctness_" - : "performance_"; - std::string generated_test_case_name = prefix + - std::to_string(config.batch_size) + "_" + - std::to_string(config.in_channels) + "_" + - std::to_string(config.height) + "_" + std::to_string(config.width); - - config.test_case_name = generated_test_case_name; - - for (const auto& storage_type : storage_types) { - test_cases.push_back( - create_test_case_from_config(config, storage_type, vkapi::kFloat)); - } - } - - return test_cases; -} - -// Reference implementation for qdq8ta_conv2d operation -void qdq8ta_conv2d_reference_impl(TestCase& test_case) { - int32_t idx = 0; - const ValueSpec& input_spec = test_case.inputs()[idx++]; - const ValueSpec& scale_spec = test_case.inputs()[idx++]; - const ValueSpec& zero_point_spec = test_case.inputs()[idx++]; - - // Extract output specification - ValueSpec& output_spec = test_case.outputs()[0]; - - // Get tensor dimensions - auto input_sizes = input_spec.get_tensor_sizes(); // [N, C, H, W] - int64_t N = input_sizes[0]; - int64_t C = input_sizes[1]; - int64_t H = input_sizes[2]; - int64_t W = input_sizes[3]; - - // Skip for large tensors since computation time will be extremely slow - if (N > kRefDimSizeLimit || C > kRefDimSizeLimit || H > kRefDimSizeLimit || - W > kRefDimSizeLimit) { - throw std::invalid_argument( - "One or more dimensions (N, C, H, W) exceed the allowed limit for reference implementation."); - } - - if (input_spec.dtype != vkapi::kFloat) { - throw std::invalid_argument("Unsupported dtype"); - } - - // Get raw data pointers - auto& input_data = input_spec.get_float_data(); - - // Extract the randomized scale and zero point values (following - // q8csw_conv2d.cpp pattern) - float scale = scale_spec.get_float_value(); - int32_t zero_point = zero_point_spec.get_int_value(); - int32_t quant_min = -128; - int32_t quant_max = 127; - - // Prepare output data - auto& ref_data = output_spec.get_ref_float_data(); - int64_t num_elements = N * C * H * W; - ref_data.resize(num_elements); - - // Perform quantize-dequantize operation on each element - for (int64_t i = 0; i < num_elements; ++i) { - float input_val = input_data[i]; - - // Quantize: quantized = round(input / scale + zero_point) - float quantized_float = std::round(input_val / scale) + zero_point; - - // Clamp to quantization range - quantized_float = std::max(quantized_float, static_cast(quant_min)); - quantized_float = std::min(quantized_float, static_cast(quant_max)); - - int32_t quantized_int = static_cast(quantized_float); - - // Dequantize: output = (quantized - zero_point) * scale - float dequantized = (quantized_int - zero_point) * scale; - - ref_data[i] = dequantized; - } -} - -int main(int argc, char* argv[]) { - set_debugging(false); - set_print_output(false); - set_print_latencies(false); - set_use_gpu_timestamps(true); - - print_performance_header(); - std::cout << "QDQ8TA Conv2D Operation Prototyping Framework" << std::endl; - print_separator(); - - ReferenceComputeFunc ref_fn = qdq8ta_conv2d_reference_impl; - - auto results = execute_test_cases( - generate_qdq8ta_conv2d_test_cases, "QDQ8TAConv2D", 0, 1, ref_fn); - - return 0; -} diff --git a/backends/vulkan/test/custom_ops/targets.bzl b/backends/vulkan/test/custom_ops/targets.bzl index 6f35444570b..6633365838c 100644 --- a/backends/vulkan/test/custom_ops/targets.bzl +++ b/backends/vulkan/test/custom_ops/targets.bzl @@ -91,7 +91,7 @@ def define_common_targets(is_fbcode = False): define_custom_op_test_binary("q8csw_conv2d") define_custom_op_test_binary("choose_qparams_per_row") define_custom_op_test_binary("q4gsw_linear") - define_custom_op_test_binary("qdq8ta_conv2d_activations") + define_custom_op_test_binary("test_q8ta_qdq") define_custom_op_test_binary("q8ta_q8csw_q8to_conv2d") define_custom_op_test_binary("q8ta_q8csw_q8to_conv2d_dw") define_custom_op_test_binary("q8ta_q8ta_q8to_add") diff --git a/backends/vulkan/test/custom_ops/test_q8ta_qdq.cpp b/backends/vulkan/test/custom_ops/test_q8ta_qdq.cpp new file mode 100644 index 00000000000..e0efd6ea85d --- /dev/null +++ b/backends/vulkan/test/custom_ops/test_q8ta_qdq.cpp @@ -0,0 +1,348 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include +#include +#include "utils.h" + +#include + +// #define DEBUG_MODE + +using namespace executorch::vulkan::prototyping; +using namespace vkcompute; + +static constexpr int64_t kRefDimSizeLimit = 512; + +// Configuration struct for tensor quantize-dequantize testing +struct QDQ8BitConfig { + std::vector shape; // Tensor shape (can be any dimensionality) + std::string test_case_name = "placeholder"; + std::string op_name = "q_dq_8bit_per_tensor"; +}; + +// Utility function to create a test case from a QDQ8BitConfig +TestCase create_test_case_from_config( + const QDQ8BitConfig& config, + utils::StorageType storage_type, + vkapi::ScalarType input_dtype, + utils::GPUMemoryLayout fp_memory_layout, + utils::GPUMemoryLayout quantized_memory_layout, + const std::string& impl_selector = "") { + TestCase test_case; + + // Create a descriptive name for the test case + // Format: ACCU/PERF I=N,C,H,W Tex_FP->Buf_Quant + std::string shape_str = shape_string(config.shape); + std::string test_name = config.test_case_name + " I=" + shape_str + " " + + repr_str(storage_type, fp_memory_layout) + "->" + + repr_str(utils::kBuffer, quantized_memory_layout); + if (!impl_selector.empty()) { + test_name += " [" + impl_selector + "]"; + } + test_case.set_name(test_name); + + // Set the operator name for the test case + std::string operator_name = "test_etvk." + config.op_name + ".default"; + test_case.set_operator_name(operator_name); + + // Input tensor (float) - any dimensionality + ValueSpec input_tensor( + config.shape, + input_dtype, + storage_type, + fp_memory_layout, + DataGenType::RANDOM); + + float scale_val = 0.007112; + ValueSpec scale(scale_val); + + // Generate random zero point within quantization range + int32_t zero_point_val = 0; + ValueSpec zero_point(zero_point_val); + + // GPUMemoryLayout as integer (will be cast in the operator) + int32_t layout_int = static_cast(quantized_memory_layout); + ValueSpec layout_spec(layout_int); + + // impl_selector string + ValueSpec impl_selector_spec = ValueSpec::make_string(impl_selector); + + // Output tensor (float) - same shape as input + ValueSpec output_tensor( + config.shape, + input_dtype, + storage_type, + fp_memory_layout, + DataGenType::ZEROS); + + // Add all specs to test case + test_case.add_input_spec(input_tensor); + test_case.add_input_spec(scale); + test_case.add_input_spec(zero_point); + test_case.add_input_spec(layout_spec); + test_case.add_input_spec(impl_selector_spec); + test_case.add_output_spec(output_tensor); + + test_case.set_abs_tolerance(scale_val + 1e-4); + + // Use layout-only filter for this test since quantize/dequantize ARE the + // operations being tested, not overhead + test_case.set_shader_filter(kLayoutOnlyShaderFilter); + + return test_case; +} + +// Generate easy test cases for q_dq_8bit operation (for debugging) +std::vector generate_q_dq_8bit_easy_cases() { + std::vector test_cases; + + // Single simple configuration for debugging + QDQ8BitConfig config = { + {1, 16, 16, 16}, // shape: [N, C, H, W] + "ACCU", // test_case_name + }; + + // FP memory layouts to test + std::vector fp_layouts = { + utils::kWidthPacked, + utils::kChannelsPacked, + }; + + // Quantized memory layouts to test + std::vector quant_layouts = { + utils::kPackedInt8_4W, + utils::kPackedInt8_4C, + utils::kPackedInt8_4W4C, + utils::kPackedInt8_4H4W, + utils::kPackedInt8_4C1W, + }; + + std::vector storage_types = {utils::kBuffer}; + std::vector float_types = {vkapi::kFloat}; + + // Generate test cases for each combination + for (const auto& fp_layout : fp_layouts) { + for (const auto& quant_layout : quant_layouts) { + for (const auto& storage_type : storage_types) { + for (const auto& input_dtype : float_types) { + test_cases.push_back(create_test_case_from_config( + config, storage_type, input_dtype, fp_layout, quant_layout)); + // For 4W4C layout, also test with legacy implementation + if (quant_layout == utils::kPackedInt8_4W4C) { + test_cases.push_back(create_test_case_from_config( + config, + storage_type, + input_dtype, + fp_layout, + quant_layout, + /*impl_selector=*/"legacy_4w4c")); + } + } + } + } + } + + return test_cases; +} + +// Generate test cases for q_dq_8bit operation +std::vector generate_q_dq_8bit_test_cases() { + std::vector test_cases; + + // Shapes to test (no layout specified - will be combined with all layouts) + std::vector> shapes = { + // Small test cases for correctness + {1, 3, 16, 16}, + {1, 8, 32, 32}, + {1, 16, 24, 24}, + {1, 32, 12, 12}, + {1, 1, 64, 64}, + {1, 3, 64, 64}, + {1, 4, 16, 16}, + + // Different tensor sizes + {1, 8, 20, 20}, + {1, 16, 14, 14}, + {1, 8, 28, 28}, + + // Odd tensor sizes + {1, 3, 15, 15}, + {1, 13, 31, 31}, + {1, 17, 23, 23}, + + // Performance test cases (larger tensors) + {1, 64, 128, 128}, + {1, 32, 64, 64}, + {1, 128, 56, 56}, + }; + + // FP memory layouts to test + std::vector fp_layouts = { + utils::kWidthPacked, + utils::kChannelsPacked, + }; + + // Quantized memory layouts to test + std::vector quant_layouts = { + utils::kPackedInt8_4W, + utils::kPackedInt8_4C, + utils::kPackedInt8_4W4C, + utils::kPackedInt8_4H4W, + utils::kPackedInt8_4C1W, + }; + + // Test with buffer storage only - the unified block-based shaders only + // support buffer-backed floating-point tensors. Texture storage is tested + // separately by qdq8ta_conv2d_activations which uses the layout-specific + // shaders. + std::vector storage_types = { + utils::kBuffer, utils::kTexture3D}; + + // Generate all combinations + for (const auto& shape : shapes) { + // Generate test case name prefix from shape dimensions + std::string prefix = "ACCU"; + for (const auto& dim : shape) { + if (dim > kRefDimSizeLimit) { + prefix = "PERF"; + break; + } + } + + for (const auto& fp_layout : fp_layouts) { + for (const auto& quant_layout : quant_layouts) { + for (const auto& storage_type : storage_types) { + QDQ8BitConfig config; + config.shape = shape; + config.test_case_name = prefix; + + test_cases.push_back(create_test_case_from_config( + config, storage_type, vkapi::kFloat, fp_layout, quant_layout)); + // For 4W4C layout, also test with legacy implementation + if (fp_layout == utils::kChannelsPacked && + quant_layout == utils::kPackedInt8_4W4C) { + test_cases.push_back(create_test_case_from_config( + config, + storage_type, + vkapi::kFloat, + fp_layout, + quant_layout, + /*impl_selector=*/"legacy_4w4c")); + } + } + } + } + } + + return test_cases; +} + +// Reference implementation for q_dq_8bit operation +void q_dq_8bit_reference_impl(TestCase& test_case) { + int32_t idx = 0; + const ValueSpec& input_spec = test_case.inputs()[idx++]; + const ValueSpec& scale_spec = test_case.inputs()[idx++]; + const ValueSpec& zero_point_spec = test_case.inputs()[idx++]; + const ValueSpec& layout_spec = test_case.inputs()[idx++]; + (void)layout_spec; // Not used in reference implementation + + // Extract output specification + ValueSpec& output_spec = test_case.outputs()[0]; + + // Get tensor dimensions (arbitrary dimensionality) + auto input_sizes = input_spec.get_tensor_sizes(); + + // Calculate total number of elements + int64_t num_elements = 1; + for (const auto& dim : input_sizes) { + num_elements *= dim; + } + + // Skip for large tensors since computation time will be extremely slow + for (const auto& dim : input_sizes) { + if (dim > kRefDimSizeLimit) { + throw std::invalid_argument( + "One or more dimensions exceed the allowed limit for reference implementation."); + } + } + + if (input_spec.dtype != vkapi::kFloat) { + throw std::invalid_argument("Unsupported dtype"); + } + + // Get raw data pointers + auto& input_data = input_spec.get_float_data(); + + // Extract the randomized scale and zero point values + float scale = scale_spec.get_float_value(); + int32_t zero_point = zero_point_spec.get_int_value(); + int32_t quant_min = -128; + int32_t quant_max = 127; + + // Prepare output data + auto& ref_data = output_spec.get_ref_float_data(); + ref_data.resize(num_elements); + + // Perform quantize-dequantize operation on each element + for (int64_t i = 0; i < num_elements; ++i) { + float input_val = input_data[i]; + + // Quantize: quantized = round(input / scale + zero_point) + float quantized_float = std::round(input_val / scale) + zero_point; + + // Clamp to quantization range + quantized_float = std::max(quantized_float, static_cast(quant_min)); + quantized_float = std::min(quantized_float, static_cast(quant_max)); + + int32_t quantized_int = static_cast(quantized_float); + + // Dequantize: output = (quantized - zero_point) * scale + float dequantized = (quantized_int - zero_point) * scale; + + ref_data[i] = dequantized; + } +} + +int main(int argc, char* argv[]) { + set_debugging(false); + set_print_output(false); +#ifdef DEBUG_MODE + set_print_latencies(false); +#else + set_print_latencies(false); +#endif + set_use_gpu_timestamps(true); + + print_performance_header(); + std::cout << "Q/DQ 8-bit Per Tensor Operation Prototyping Framework" + << std::endl; + print_separator(); + + ReferenceComputeFunc ref_fn = q_dq_8bit_reference_impl; + + auto results = execute_test_cases( +#ifdef DEBUG_MODE + generate_q_dq_8bit_easy_cases, +#else + generate_q_dq_8bit_test_cases, +#endif + "QDQ8Bit", +#ifdef DEBUG_MODE + 0, + 1, +#else + 3, + 10, +#endif + ref_fn); + + return 0; +}