diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 136f969896f5..b1d3ae9cc2d0 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -676,13 +676,12 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { spirv::Value init_value = MakeValue(op->min); PrimExpr end = is_zero(op->min) ? op->extent : analyzer_->Simplify(op->min + op->extent); spirv::Value end_value = MakeValue(end); - spirv::PhiValue loop_var = builder_->MakePhi(init_value.stype, 2); // loop step spirv::Value step; if (op->HasTrivialStep()) { - step = op->loop_var.dtype().is_int() ? builder_->IntImm(loop_var.stype, 1) - : builder_->UIntImm(loop_var.stype, 1); + step = op->loop_var.dtype().is_int() ? builder_->IntImm(init_value.stype, 1) + : builder_->UIntImm(init_value.stype, 1); } else { step = MakeValue(tvm::cast(end->dtype, *op->step)); } @@ -699,8 +698,9 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { builder_->SetName(merge_label, "for_loop_merge"); builder_->MakeInst(spv::OpBranch, head_label); - // Loop head + // Loop head - Phi must be created AFTER StartLabel so it's in the head block builder_->StartLabel(head_label); + spirv::PhiValue loop_var = builder_->MakePhi(init_value.stype, 2); loop_var.SetIncoming(0, init_value, init_label); spirv::Value loop_cond = builder_->LT(loop_var, end_value); uint32_t control = diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc index bac66a3aacf7..f912e482761c 100644 --- a/src/target/spirv/ir_builder.cc +++ b/src/target/spirv/ir_builder.cc @@ -34,7 +34,7 @@ namespace spirv { IRBuilder::IRBuilder(const SPIRVSupport& support) : spirv_support_(support) {} void IRBuilder::InitHeader() { - ICHECK_EQ(header_.size(), 0U); + TVM_FFI_ICHECK_EQ(header_.size(), 0U); header_.push_back(spv::MagicNumber); // Target SPIR-V version 1.0. Additional functionality will be @@ -126,7 +126,7 @@ SType IRBuilder::GetSType(const DataType& dtype, uint32_t row, uint32_t col) { type_key = static_cast(dtype.code()); type_key |= static_cast(dtype.bits()) << 8U; if (row * col == 0) { - ICHECK((row == 0) && (col == 0)); + TVM_FFI_ICHECK((row == 0) && (col == 0)); type_key |= static_cast(dtype.lanes()) << 16U; } else { type_key |= static_cast(row) << 32U; @@ -143,7 +143,7 @@ SType IRBuilder::GetSType(const DataType& dtype, uint32_t row, uint32_t col) { } SType IRBuilder::GetPointerType(const SType& value_type, spv::StorageClass storage_class) { - ICHECK_NE(storage_class, spv::StorageClassMax); + TVM_FFI_ICHECK_NE(storage_class, spv::StorageClassMax); auto key = std::make_pair(value_type.id, storage_class); auto it = pointer_type_tbl_.find(key); if (it != pointer_type_tbl_.end()) { @@ -178,23 +178,24 @@ SType IRBuilder::GetStructArrayType(const SType& value_type, uint32_t num_elems, } else { ib_.Begin(spv::OpTypeRuntimeArray).AddSeq(arr_type, value_type).Commit(&global_); } - int nbits = value_type.type.bits() * value_type.type.lanes(); - ICHECK_EQ(nbits % 8, 0); - uint32_t nbytes = static_cast(nbits) / 8; - // decorate the array type. - this->Decorate(spv::OpDecorate, arr_type, spv::DecorationArrayStride, nbytes); + if (interface_block) { + int nbits = value_type.type.bits() * value_type.type.lanes(); + TVM_FFI_ICHECK_EQ(nbits % 8, 0); + uint32_t nbytes = static_cast(nbits) / 8; + // Explicit layout is required for descriptor-backed interface blocks. + this->Decorate(spv::OpDecorate, arr_type, spv::DecorationArrayStride, nbytes); + } // declare struct of array SType struct_type; struct_type.id = id_counter_++; struct_type.type = DataType::Handle(); struct_type.element_type_id = value_type.id; ib_.Begin(spv::OpTypeStruct).AddSeq(struct_type, arr_type).Commit(&global_); - // decorate the array type. - ib_.Begin(spv::OpMemberDecorate) - .AddSeq(struct_type, 0, spv::DecorationOffset, 0) - .Commit(&decorate_); if (interface_block) { + ib_.Begin(spv::OpMemberDecorate) + .AddSeq(struct_type, 0, spv::DecorationOffset, 0) + .Commit(&decorate_); // Runtime array are always decorated as Block or BufferBlock // (shader storage buffer) if (spirv_support_.supports_storage_buffer_storage_class) { @@ -214,7 +215,7 @@ SType IRBuilder::GetStructArrayType(const SType& value_type, uint32_t num_elems, } Value IRBuilder::StructArrayAccess(const SType& res_type, Value buffer, Value index) { - ICHECK(buffer.flag == kStructArrayPtr); + TVM_FFI_ICHECK(buffer.flag == kStructArrayPtr); return MakeValue(spv::OpInBoundsAccessChain, res_type, buffer, const_i32_zero_, index); } @@ -233,7 +234,7 @@ Value IRBuilder::FloatImm(const SType& dtype, double value) { uint64_t data = ptr[0]; return GetConst_(dtype, &data); } else { - ICHECK_EQ(dtype.type.bits(), 16); + TVM_FFI_ICHECK_EQ(dtype.type.bits(), 16); float fvalue = static_cast(value); uint32_t* ptr = reinterpret_cast(&fvalue); uint64_t data = ptr[0]; @@ -283,13 +284,13 @@ Value IRBuilder::DeclareStorageVariable(const std::vector& value_types, .Commit(&decorate_); DataType t = value_types[i].type; uint32_t nbits = t.bits() * t.lanes(); - ICHECK_EQ(nbits % 8, 0); + TVM_FFI_ICHECK_EQ(nbits % 8, 0); uint32_t bytes = (nbits / 8); if (t.bits() == 32) { // In our Vulkan runtime, each scalar argument always occupies 64 bit. offset += bytes * 2; } else { - ICHECK_EQ(t.bits(), 64); + TVM_FFI_ICHECK_EQ(t.bits(), 64); offset += bytes; } } @@ -302,7 +303,7 @@ Value IRBuilder::DeclareStorageVariable(const std::vector& value_types, } Value IRBuilder::DeclarePushConstant(const std::vector& value_types) { - ICHECK_EQ(push_const_.id, 0); + TVM_FFI_ICHECK_EQ(push_const_.id, 0); return DeclareStorageVariable(value_types, spv::StorageClassPushConstant, kPushConstantPtr); } @@ -335,7 +336,7 @@ Value IRBuilder::GetUniform(Value ptr_push_const, const SType& v_type, uint32_t Value IRBuilder::NewFunction() { return NewValue(t_void_func_, kFunction); } void IRBuilder::CommitKernelFunction(const Value& func, const std::string& name) { - ICHECK_EQ(func.flag, kFunction); + TVM_FFI_ICHECK_EQ(func.flag, kFunction); ib_.Begin(spv::OpEntryPoint).AddSeq(spv::ExecutionModelGLCompute, func, name); for (auto& it : built_in_tbl_) { ib_.Add(it.second); @@ -344,7 +345,7 @@ void IRBuilder::CommitKernelFunction(const Value& func, const std::string& name) } void IRBuilder::StartFunction(const Value& func) { - ICHECK_EQ(func.flag, kFunction); + TVM_FFI_ICHECK_EQ(func.flag, kFunction); // add function declaration to the header. ib_.Begin(spv::OpFunction).AddSeq(t_void_, func, 0, t_void_func_).Commit(&func_header_); @@ -354,7 +355,7 @@ void IRBuilder::StartFunction(const Value& func) { } void IRBuilder::SetLocalSize(const Value& func, uint32_t local_size[3]) { - ICHECK_EQ(func.flag, kFunction); + TVM_FFI_ICHECK_EQ(func.flag, kFunction); ib_.Begin(spv::OpExecutionMode) .AddSeq(func, spv::ExecutionModeLocalSize, local_size[0], local_size[1], local_size[2]) .Commit(&exec_mode_); @@ -362,7 +363,7 @@ void IRBuilder::SetLocalSize(const Value& func, uint32_t local_size[3]) { Value IRBuilder::Allocate(const SType& value_type, uint32_t num_elems, spv::StorageClass storage_class) { - ICHECK_NE(num_elems, 0U); + TVM_FFI_ICHECK_NE(num_elems, 0U); SType sarr_type = GetStructArrayType(value_type, num_elems, false); SType ptr_type = GetPointerType(sarr_type, storage_class); Value val = NewValue(ptr_type, kStructArrayPtr); @@ -403,7 +404,7 @@ Value IRBuilder::GetBuiltInValue(spv::BuiltIn built_in, uint32_t index, const st break; default: - LOG(FATAL) << "No data type defined for SPIR-V Built-In " << built_in; + TVM_FFI_THROW(InternalError) << "No data type defined for SPIR-V Built-In " << built_in; } // Look up the decorated array value at global scope. If it doesn't @@ -465,7 +466,7 @@ Value IRBuilder::GetConst_(const SType& dtype, const uint64_t* pvalue) { if (it != const_tbl_.end()) { return it->second; } - ICHECK_LE(dtype.type.bits(), 64); + TVM_FFI_ICHECK_LE(dtype.type.bits(), 64); Value ret = NewValue(dtype, kConstant); if (dtype.type == DataType::Bool()) { // bool types. @@ -510,7 +511,7 @@ SType IRBuilder::DeclareType(const DataType& dtype, uint32_t row, uint32_t col) } else if (dtype.is_float()) { ib_.Begin(spv::OpTypeFloat).AddSeq(t, dtype.bits()).Commit(&global_); } else { - LOG(FATAL) << "declare type do not support handle"; + TVM_FFI_THROW(InternalError) << "declare type do not support handle"; } return t; } else { @@ -520,7 +521,7 @@ SType IRBuilder::DeclareType(const DataType& dtype, uint32_t row, uint32_t col) SType base_type = GetSType(dtype.element_of()); if (row * col == 0) { - ICHECK((row == 0) && (col == 0)); + TVM_FFI_ICHECK((row == 0) && (col == 0)); ib_.Begin(spv::OpTypeVector).AddSeq(t, base_type, dtype.lanes()).Commit(&global_); } else { Value v_row = GetSpecConst(GetSType(DataType::UInt(32)), row); @@ -538,21 +539,21 @@ void IRBuilder::AddCapabilityFor(const DataType& dtype) { // Declare appropriate capabilities for int/float types if (dtype.is_int() || dtype.is_uint()) { if (dtype.bits() == 8) { - ICHECK(spirv_support_.supports_int8) + TVM_FFI_ICHECK(spirv_support_.supports_int8) << "Vulkan target does not support Int8 capability. " << "If your device supports 8-bit int operations, " << "please either add -supports_int8=1 to the target, " << "or query all device parameters by adding -from_device=0."; capabilities_used_.insert(spv::CapabilityInt8); } else if (dtype.bits() == 16) { - ICHECK(spirv_support_.supports_int16) + TVM_FFI_ICHECK(spirv_support_.supports_int16) << "Vulkan target does not support Int16 capability. " << "If your device supports 16-bit int operations, " << "please either add -supports_int16=1 to the target, " << "or query all device parameters by adding -from_device=0."; capabilities_used_.insert(spv::CapabilityInt16); } else if (dtype.bits() == 64) { - ICHECK(spirv_support_.supports_int64) + TVM_FFI_ICHECK(spirv_support_.supports_int64) << "Vulkan target does not support Int64 capability. " << "If your device supports 64-bit int operations, " << "please either add -supports_int64=1 to the target, " @@ -562,14 +563,14 @@ void IRBuilder::AddCapabilityFor(const DataType& dtype) { } else if (dtype.is_float()) { if (dtype.bits() == 16) { - ICHECK(spirv_support_.supports_float16) + TVM_FFI_ICHECK(spirv_support_.supports_float16) << "Vulkan target does not support Float16 capability. " << "If your device supports 16-bit float operations, " << "please either add -supports_float16=1 to the target, " << "or query all device parameters by adding -from_device=0."; capabilities_used_.insert(spv::CapabilityFloat16); } else if (dtype.bits() == 64) { - ICHECK(spirv_support_.supports_float64) + TVM_FFI_ICHECK(spirv_support_.supports_float64) << "Vulkan target does not support Float64 capability. " << "If your device supports 64-bit float operations, " << "please either add -supports_float64=1 to the target, " @@ -584,7 +585,7 @@ void IRBuilder::AddCapabilityFor(const DataType& dtype) { // Int8 prevents use of an 8-bit loop iterator on a device that // supports Int8 but doesn't support 8-bit buffer access. if (dtype.bits() == 8 && !dtype.is_bool()) { - ICHECK(spirv_support_.supports_storage_buffer_8bit_access) + TVM_FFI_ICHECK(spirv_support_.supports_storage_buffer_8bit_access) << "Vulkan target does not support StorageBuffer8BitAccess. " << "If your device supports 8-bit buffer access, " << "please either add -supports_8bit_buffer=1 to the target, " @@ -592,14 +593,14 @@ void IRBuilder::AddCapabilityFor(const DataType& dtype) { capabilities_used_.insert(spv::CapabilityStorageBuffer8BitAccess); extensions_used_.insert("SPV_KHR_8bit_storage"); - ICHECK(spirv_support_.supports_storage_buffer_storage_class) + TVM_FFI_ICHECK(spirv_support_.supports_storage_buffer_storage_class) << "Illegal Vulkan target description. " << "Vulkan spec requires extension VK_KHR_storage_buffer_storage_class " << "if VK_KHR_8bit_storage is supported. " << "Please either add -supports_storage_buffer_storage_class=1 to the target, " << "or query all device parameters by adding -from_device=0."; } else if (dtype.bits() == 16) { - ICHECK(spirv_support_.supports_storage_buffer_16bit_access) + TVM_FFI_ICHECK(spirv_support_.supports_storage_buffer_16bit_access) << "Vulkan target does not support StorageBuffer16BitAccess. " << "If your device supports 16-bit buffer access, " << "please either add -supports_16bit_buffer=1 to the target, " @@ -625,7 +626,7 @@ PhiValue IRBuilder::MakePhi(const SType& out_type, uint32_t num_incoming) { phi.stype = out_type; phi.flag = kNormal; phi.instr = ib_.Commit(&function_); - ICHECK_EQ(phi.instr.WordCount(), 2 * num_incoming + 3); + TVM_FFI_ICHECK_EQ(phi.instr.WordCount(), 2 * num_incoming + 3); return phi; } @@ -643,11 +644,11 @@ Value IRBuilder::CallGLSL450(const SType& ret_type, uint32_t inst_id, Value IRBuilder::CallKHRIntegerDotProduct(const SType& ret_type, const std::vector& args, const DataType& dtype) { if (args.size() != 3) { - LOG(FATAL) << "Unresolved arguments in SPIRV_KHR_integer_dot_product"; + TVM_FFI_THROW(InternalError) << "Unresolved arguments in SPIRV_KHR_integer_dot_product"; } Value val = NewValue(ret_type, kNormal); #ifdef TVM_SPIRV_KHR_INTEGER_DOT_PRODUCT - ICHECK(spirv_support_.supports_integer_dot_product) + TVM_FFI_ICHECK(spirv_support_.supports_integer_dot_product) << "Vulkan target does not support integer dot product capability. " << "If your device supports integer dot product operations, " << "please either add -mattr=+dotprod to the target, " @@ -657,10 +658,11 @@ Value IRBuilder::CallKHRIntegerDotProduct(const SType& ret_type, const std::vect } else if (dtype.is_uint()) { ib_.Begin(spv::OpUDotAccSatKHR).AddSeq(ret_type, val); } else { - LOG(FATAL) << "Unsupported type"; + TVM_FFI_THROW(InternalError) << "Unsupported type"; } #else - LOG(FATAL) << "Please turn on USE_SPIRV_KHR_INTEGER_DOT_PRODUCT in config.cmake"; + TVM_FFI_THROW(InternalError) + << "Please turn on USE_SPIRV_KHR_INTEGER_DOT_PRODUCT in config.cmake"; #endif for (const Value& v : args) { @@ -675,7 +677,7 @@ Value IRBuilder::Concat(const std::vector& vec) { DataType etype = vec[0].stype.type; int lanes = etype.lanes(); for (size_t i = 1; i < vec.size(); ++i) { - ICHECK_EQ(etype, vec[i].stype.type.element_of()) + TVM_FFI_ICHECK_EQ(etype, vec[i].stype.type.element_of()) << "Cannot concat vector of different element type"; lanes += vec[i].stype.type.lanes(); is_const = is_const && (vec[i].flag == kConstant); @@ -700,11 +702,11 @@ Value IRBuilder::Concat(const std::vector& vec) { } Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) { - ICHECK_NE(value.stype.id, 0U); + TVM_FFI_ICHECK_NE(value.stype.id, 0U); if (value.stype.id == dst_type.id) return value; const tvm::DataType& from = value.stype.type; const tvm::DataType& to = dst_type.type; - ICHECK_EQ(from.lanes(), to.lanes()); + TVM_FFI_ICHECK_EQ(from.lanes(), to.lanes()); if (from == DataType::Bool()) { if (to.is_int()) { return Select(value, IntImm(dst_type, 1), IntImm(dst_type, 0)); @@ -714,7 +716,7 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) { return MakeValue(spv::OpConvertUToF, dst_type, Select(value, UIntImm(t_uint32_, 1), UIntImm(t_uint32_, 0))); } else { - LOG(FATAL) << "cannot cast from " << from << " to " << to; + TVM_FFI_THROW(InternalError) << "cannot cast from " << from << " to " << to; return Value(); } } else if (to == DataType::Bool()) { @@ -723,7 +725,7 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) { } else if (to.is_uint()) { return NE(value, UIntImm(value.stype, 0)); } else { - LOG(FATAL) << "cannot cast from " << from << " to " << to; + TVM_FFI_THROW(InternalError) << "cannot cast from " << from << " to " << to; return Value(); } } else if (from.is_int() && to.is_int()) { @@ -751,7 +753,7 @@ Value IRBuilder::Cast(const SType& dst_type, spirv::Value value) { } else if (from.is_float() && to.is_float()) { return MakeValue(spv::OpFConvert, dst_type, value); } else { - LOG(FATAL) << "do not support type cast from " << from << " to " << to; + TVM_FFI_THROW(InternalError) << "do not support type cast from " << from << " to " << to; return Value(); } } @@ -772,7 +774,7 @@ Value IRBuilder::GetCompositeConst(const SType& ele_stype, const SType& composit } Value IRBuilder::GetSpecConst(const SType& dtype, uint64_t value) { - ICHECK_LE(dtype.type.bits(), 32); + TVM_FFI_ICHECK_LE(dtype.type.bits(), 32); Value ret = NewValue(dtype, kSpecConst); ib_.Begin(spv::OpSpecConstant).AddSeq(dtype, ret); ib_.Add(static_cast(value)); @@ -782,24 +784,24 @@ Value IRBuilder::GetSpecConst(const SType& dtype, uint64_t value) { #define DEFINE_BUILDER_BINARY_USIGN_OP(_OpName, _Op) \ Value IRBuilder::_OpName(Value a, Value b) { \ - ICHECK_EQ(a.stype.id, b.stype.id); \ + TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); \ if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ return MakeValue(spv::OpI##_Op, a.stype, a, b); \ } else { \ - ICHECK(a.stype.type.is_float()); \ + TVM_FFI_ICHECK(a.stype.type.is_float()); \ return MakeValue(spv::OpF##_Op, a.stype, a, b); \ } \ } #define DEFINE_BUILDER_BINARY_SIGN_OP(_OpName, _Op) \ Value IRBuilder::_OpName(Value a, Value b) { \ - ICHECK_EQ(a.stype.id, b.stype.id); \ + TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); \ if (a.stype.type.is_int()) { \ return MakeValue(spv::OpS##_Op, a.stype, a, b); \ } else if (a.stype.type.is_uint()) { \ return MakeValue(spv::OpU##_Op, a.stype, a, b); \ } else { \ - ICHECK(a.stype.type.is_float()); \ + TVM_FFI_ICHECK(a.stype.type.is_float()); \ return MakeValue(spv::OpF##_Op, a.stype, a, b); \ } \ } @@ -810,28 +812,28 @@ DEFINE_BUILDER_BINARY_USIGN_OP(Mul, Mul); DEFINE_BUILDER_BINARY_SIGN_OP(Div, Div); Value IRBuilder::Mod(Value a, Value b) { - ICHECK_EQ(a.stype.id, b.stype.id); + TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); if (a.stype.type.is_int()) { return MakeValue(spv::OpSRem, a.stype, a, b); } else if (a.stype.type.is_uint()) { return MakeValue(spv::OpUMod, a.stype, a, b); } else { - ICHECK(a.stype.type.is_float()); + TVM_FFI_ICHECK(a.stype.type.is_float()); return MakeValue(spv::OpFRem, a.stype, a, b); } } #define DEFINE_BUILDER_CMP_OP(_OpName, _Op) \ Value IRBuilder::_OpName(Value a, Value b) { \ - ICHECK_EQ(a.stype.id, b.stype.id); \ - ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ + TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); \ + TVM_FFI_ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ const auto& bool_type = this->GetSType(DataType::Bool().with_lanes(a.stype.type.lanes())); \ if (a.stype.type.is_int()) { \ return MakeValue(spv::OpS##_Op, bool_type, a, b); \ } else if (a.stype.type.is_uint()) { \ return MakeValue(spv::OpU##_Op, bool_type, a, b); \ } else { \ - ICHECK(a.stype.type.is_float()); \ + TVM_FFI_ICHECK(a.stype.type.is_float()); \ return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ } \ } @@ -843,13 +845,13 @@ DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual); #define DEFINE_BUILDER_CMP_UOP(_OpName, _Op) \ Value IRBuilder::_OpName(Value a, Value b) { \ - ICHECK_EQ(a.stype.id, b.stype.id); \ - ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ + TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); \ + TVM_FFI_ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes()); \ const auto& bool_type = this->GetSType(DataType::Bool().with_lanes(a.stype.type.lanes())); \ if (a.stype.type.is_int() || a.stype.type.is_uint()) { \ return MakeValue(spv::OpI##_Op, bool_type, a, b); \ } else { \ - ICHECK(a.stype.type.is_float()); \ + TVM_FFI_ICHECK(a.stype.type.is_float()); \ return MakeValue(spv::OpFOrd##_Op, bool_type, a, b); \ } \ } @@ -858,8 +860,8 @@ DEFINE_BUILDER_CMP_UOP(EQ, Equal); DEFINE_BUILDER_CMP_UOP(NE, NotEqual); Value IRBuilder::Select(Value cond, Value a, Value b) { - ICHECK_EQ(a.stype.id, b.stype.id); - ICHECK_EQ(cond.stype.type.element_of(), DataType::Bool()); + TVM_FFI_ICHECK_EQ(a.stype.id, b.stype.id); + TVM_FFI_ICHECK_EQ(cond.stype.type.element_of(), DataType::Bool()); return MakeValue(spv::OpSelect, a.stype, cond, a, b); } diff --git a/tests/python/codegen/test_target_codegen_vulkan.py b/tests/python/codegen/test_target_codegen_vulkan.py index cf7b46692661..46625771b0c1 100644 --- a/tests/python/codegen/test_target_codegen_vulkan.py +++ b/tests/python/codegen/test_target_codegen_vulkan.py @@ -570,6 +570,24 @@ def kernel(): vulkan_codegen(mod, target) +@tvm.testing.requires_vulkan(support_required="compile-only") +def test_codegen_static_shared_memory(): + """The codegen should accept static shared/workgroup allocations.""" + + @I.ir_module + class Module: + @T.prim_func + def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), "float32")): + A_shared = T.alloc_buffer((128,), dtype="float32", scope="shared") + + for bx in T.thread_binding(1, thread="blockIdx.x"): + for tx in T.thread_binding(128, thread="threadIdx.x"): + A_shared[tx] = A[tx] + B[tx] = A_shared[tx] + + tvm.compile(Module, target="vulkan") + + @tvm.testing.requires_gpu @tvm.testing.requires_vulkan def test_unary():